• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

tonegas / nnodely / 16502811447

24 Jul 2025 04:44PM UTC coverage: 97.767% (+0.1%) from 97.651%
16502811447

push

github

web-flow
New version 1.5.0

This pull request introduces version 1.5.0 of **nnodely**, featuring several updates:
1. Improved clarity of documentation and examples.
2. Support for managing multi-dataset features is now available.
3. DataFrames can now be used to create datasets.
4. Datasets can now be resampled.
5. Random data training has been fixed for both classic and recurrent training.
6. The `state` variable has been removed.
7. It is now possible to add or remove a connection or a closed loop.
8. Partial models can now be exported.
9. The `train` function and the result analysis have been separated.
10. A new function, `trainAndAnalyse`, is now available.
11. The report now works across all network types.
12. The training function code has been reorganized.

2901 of 2967 new or added lines in 53 files covered. (97.78%)

16 existing lines in 6 files now uncovered.

12652 of 12941 relevant lines covered (97.77%)

0.98 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.49
/nnodely/support/jsonutils.py
1
import copy
1✔
2
from pprint import pformat
1✔
3

4

5
from nnodely.support.utils import check
1✔
6

7
from nnodely.support.logger import logging, nnLogger
1✔
8
# Module-level logger for this file, created with level logging.CRITICAL
# (presumably so the debug/warning messages emitted below are silenced by
# default — NOTE(review): confirm against nnLogger's semantics).
log = nnLogger(__name__, logging.CRITICAL)
1✔
9

10
def get_window(obj):
    """Return the window kind declared in ``obj.dim``.

    ``'tw'`` takes precedence over ``'sw'``; ``None`` is returned when
    ``obj.dim`` declares neither.
    """
    for window_kind in ('tw', 'sw'):
        if window_kind in obj.dim:
            return window_kind
    return None
12

13
# Code for compressing relations (currently disabled; kept for reference)
14
        #print(self.json['Relations'])
15
        # used_rel = {string for values in self.json['Relations'].values() for string in values[1]}
16
        # if obj1.name not in used_rel and obj1.name in self.json['Relations'].keys() and self.json['Relations'][obj1.name][0] == add_relation_name:
17
        #     self.json['Relations'][self.name] = [add_relation_name, self.json['Relations'][obj1.name][1]+[obj2.name]]
18
        #     del self.json['Relations'][obj1.name]
19
        # else:
20
        # TODO: add a step that removes an Add/Sub/Mul/Div relation when it can be merged into another relation of the same type
21
        #
22
def merge(source, destination, main = True):
    """Recursively merge the ``source`` json into ``destination``.

    At the top level (``main=True``) the function first checks that the
    ``Functions`` and ``Parameters`` entries declared in both jsons are
    compatible. It then merges into a deep copy of ``destination``, so the
    caller's dict is not modified. Recursive calls (``main=False``) merge in
    place into the ``destination`` node they receive.

    Rules for non-dict values coming from ``source``:
      * ``'tw'``/``'sw'`` lists already present in the result are widened to
        cover both ranges (smallest start, largest end);
      * any other list already present in the result is kept unchanged;
      * every other value overwrites the one in the result.

    Args:
        source: json whose content is merged in.
        destination: json that receives the content.
        main: ``True`` for the top-level call (validation, copy and logging).

    Returns:
        The merged json.

    Raises:
        TypeError: if a ParamFun appears with different numbers of inputs, or
            a Parameter appears with different dimensions or windows.
    """
    if main:
        source_functions = source["Functions"]
        for key, value in destination["Functions"].items():
            # Both sides must declare 'n_input' for the comparison to be meaningful.
            if key in source_functions and 'n_input' in value and 'n_input' in source_functions[key]:
                check(value['n_input'] == source_functions[key]['n_input'],
                      TypeError,
                      f"The ParamFun {key} is present multiple times, with different number of inputs. "
                      f"The ParamFun {key} is called with {value['n_input']} parameters and with {source_functions[key]['n_input']} parameters.")

        source_parameters = source["Parameters"]
        for key, value in destination["Parameters"].items():
            if key not in source_parameters:
                continue
            source_value = source_parameters[key]
            if 'dim' in value and 'dim' in source_value:
                check(value['dim'] == source_value['dim'],
                      TypeError,
                      f"The Parameter {key} is present multiple times, with different dimensions. "
                      f"The Parameter {key} is called with {value['dim']} dimension and with {source_value['dim']} dimension.")
            window_dest = 'tw' if 'tw' in value else ('sw' if 'sw' in value else None)
            window_source = 'tw' if 'tw' in source_value else ('sw' if 'sw' in source_value else None)
            if window_dest is not None:
                # The source may declare no window at all. Only index it when a window
                # exists; otherwise building the error message would raise KeyError
                # instead of the intended TypeError.
                source_window = source_value[window_source] if window_source is not None else None
                check(window_dest == window_source and value[window_dest] == source_window,
                      TypeError,
                      f"The Parameter {key} is present multiple times, with different window. "
                      f"The Parameter {key} is called with {window_dest}={value[window_dest]} dimension and with {window_source}={source_window} dimension.")

        log.debug("Merge Source")
        log.debug("\n"+pformat(source))
        log.debug("Merge Destination")
        log.debug("\n"+pformat(destination))
        result = copy.deepcopy(destination)
    else:
        result = destination

    for key, value in source.items():
        if isinstance(value, dict):
            # Get the matching node (creating it if missing) and merge into it in place.
            node = result.setdefault(key, {})
            merge(value, node, False)
        elif key in result and type(result[key]) is list:
            if key in ('tw', 'sw'):
                # Widen the window so that it covers both ranges.
                if result[key][0] > value[0]:
                    result[key][0] = value[0]
                if result[key][1] < value[1]:
                    result[key][1] = value[1]
            # Any other list already present in the destination is kept.
        else:
            result[key] = value

    if main:
        log.debug("Merge Result")
        log.debug("\n" + pformat(result))
    return result
70

71
def get_models_json(json):
    """Summarise a model json as the list of names in each main section.

    Returns a dict with keys ``Parameters``, ``Constants``, ``Inputs``,
    ``Outputs``, ``Functions`` and ``Relations`` (in that order). Each value
    is the list of keys of the corresponding section of ``json``.
    """
    section_names = ('Parameters', 'Constants', 'Inputs', 'Outputs', 'Functions', 'Relations')
    return {section: list(json[section].keys()) for section in section_names}
80

81
def check_model(json):
    """Verify that every declared input is needed by at least one output.

    The sub-jsons of all outputs are merged and their inputs are compared
    with the inputs declared in ``json``.

    Returns:
        ``json`` unchanged.

    Raises:
        RuntimeError: if some declared inputs are not used by any output.
    """
    from nnodely.basic.relation import MAIN_JSON
    declared_inputs = json['Inputs'].keys()

    used_json = MAIN_JSON
    for output_name in json['Outputs'].keys():
        used_json = merge(used_json, subjson_from_output(json, output_name))
    used_inputs = used_json['Inputs'].keys()

    unused_inputs = set(declared_inputs) - set(used_inputs)
    check(declared_inputs == used_inputs, RuntimeError,
          f'Connect or close loop operation on the inputs {list(unused_inputs)}, that are not used in the model.')
    return json
95

96
def binary_cheks(self, obj1, obj2, name):
    """Validate the two operands of a binary stream operation.

    Both operands are converted with ``toStream`` and must then be ``Stream``
    instances. When both declare a window (``'tw'`` or ``'sw'`` in ``dim``),
    the window kind and size must match. The feature dimensions must match
    unless one operand's ``dim`` is exactly ``{'dim': 1}``.

    Args:
        self: unused; kept for call compatibility.
        obj1: first operand (anything ``toStream`` accepts).
        obj2: second operand (anything ``toStream`` accepts).
        name: label of the operation, used in the error messages.

    Returns:
        ``(obj1, obj2, dim)``: the converted operands and the result dimension
        (union of both ``dim`` dicts, with ``'dim'`` set to the larger size).

    Raises:
        TypeError: if an operand is not a Stream or the window kinds differ.
        ValueError: if the window sizes or feature dimensions are incompatible.
    """
    from nnodely.basic.relation import Stream, toStream
    obj1, obj2 = toStream(obj1), toStream(obj2)
    for operand in (obj1, obj2):
        # The message names the actual operation instead of a hard-coded "add".
        check(type(operand) is Stream, TypeError,
              f"The type of {operand} is {type(operand)} and is not supported for {name}.")

    window_obj1 = get_window(obj1)
    window_obj2 = get_window(obj2)
    if window_obj1 is not None and window_obj2 is not None:
        check(window_obj1 == window_obj2, TypeError,
              f"For {name} the time window type must match or None but they were {window_obj1} and {window_obj2}.")
        check(obj1.dim[window_obj1] == obj2.dim[window_obj2], ValueError,
              f"For {name} the time window must match or None but they were {window_obj1}={obj1.dim[window_obj1]} and {window_obj2}={obj2.dim[window_obj2]}.")
    check(obj1.dim['dim'] == obj2.dim['dim'] or obj1.dim == {'dim': 1} or obj2.dim == {'dim': 1}, ValueError,
          f"For {name} the dimension of {obj1.name} = {obj1.dim} must be the same of {obj2.name} = {obj2.dim}.")

    dim = obj1.dim | obj2.dim
    dim['dim'] = max(obj1.dim['dim'], obj2.dim['dim'])
    return obj1, obj2, dim
115

116
def subjson_from_relation(json, relation):
    """Extract the part of ``json`` needed to compute ``relation``.

    The function walks the dependency graph backwards from ``relation`` and
    collects every input, constant, parameter, function and relation it
    reaches. An input with a local (``local == 1``) ``connect`` or
    ``closedLoop`` link also pulls in the node that feeds it.

    Args:
        json: full model json (deep-copied, never modified).
        relation: name of the node to start from (normally a relation name).

    Returns:
        A new json based on ``MAIN_JSON``. It contains only the reached
        ``Relations``, ``Inputs``, ``Constants``, ``Parameters`` and
        ``Functions``, with empty ``Outputs`` and ``Info``.
    """
    json = copy.deepcopy(json)
    inputs = set()
    relations = set()
    constants = set()
    parameters = set()
    functions = set()
    # Nodes already explored. Closed loops make the graph cyclic, so without
    # this guard the recursion would never terminate.
    visited = set()
    # Relation types whose extra arguments (items after the input list) name
    # parameters/functions that must be collected as well.
    relations_with_args = ('Fir', 'Linear', 'Fuzzify', 'ParamFun')

    def search(rel):
        if rel in visited:
            return
        visited.add(rel)
        if rel in json['Inputs']:  # Found an input
            inputs.add(rel)
            input_json = json['Inputs'][rel]
            # Closed loops are stored under 'closedLoop', the same key used
            # everywhere else in this module.
            for link in ('connect', 'closedLoop'):
                if link in input_json and input_json['local'] == 1:
                    search(input_json[link])
        elif rel in json['Constants']:
            constants.add(rel)
        elif rel in json['Parameters']:
            parameters.add(rel)
        elif rel in json['Functions']:
            functions.add(rel)
            for sub_rel in json['Functions'][rel].get('params_and_consts', []):
                search(sub_rel)
        elif rel in json['Relations']:  # Another relation
            relations.add(rel)
            rel_json = json['Relations'][rel]
            for sub_rel in rel_json[1]:
                search(sub_rel)
            if rel_json[0] in relations_with_args:
                for sub_rel in rel_json[2:]:
                    search(sub_rel)

    search(relation)
    from nnodely.basic.relation import MAIN_JSON
    sub_json = copy.deepcopy(MAIN_JSON)
    sub_json['Relations'] = {key: value for key, value in json['Relations'].items() if key in relations}
    sub_json['Inputs'] = {key: value for key, value in json['Inputs'].items() if key in inputs}
    sub_json['Constants'] = {key: value for key, value in json['Constants'].items() if key in constants}
    sub_json['Parameters'] = {key: value for key, value in json['Parameters'].items() if key in parameters}
    sub_json['Functions'] = {key: value for key, value in json['Functions'].items() if key in functions}
    sub_json['Outputs'] = {}
    sub_json['Info'] = {}
    return sub_json
167

168

169
def subjson_from_output(json, outputs:str|list):
    """Build the sub-json needed to compute one or more outputs.

    Args:
        json: full model json (deep-copied, never modified).
        outputs: an output name, or an iterable of output names.

    Returns:
        A json based on ``MAIN_JSON``. It holds the merged sub-jsons of the
        relations behind each requested output (see ``subjson_from_relation``)
        plus the requested ``Outputs`` entries themselves.
    """
    from nnodely.basic.relation import MAIN_JSON
    source_json = copy.deepcopy(json)
    output_names = [outputs] if type(outputs) is str else outputs
    result = copy.deepcopy(MAIN_JSON)
    for output_name in output_names:
        relation_name = source_json['Outputs'][output_name]
        result = merge(result, subjson_from_relation(source_json, relation_name))
        result['Outputs'][output_name] = relation_name
    return result
179

180
def subjson_from_model(json, models:str|list):
    """Extract the sub-json of one or more models declared in ``json['Models']``.

    Args:
        json: full model json (deep-copied, never modified). ``json['Models']``
            is either a single model name (str) or a dict keyed by model name
            whose entries list the model's ``'Outputs'``.
        models: a model name, or a list of model names.

    Returns:
        The merged sub-json computing every output of the selected models.
        Its ``Models`` entry is the model name when a single model is
        selected; otherwise it is a dict with a shallow copy of each selected
        model's entry. Non-local ``connect``/``closedLoop`` links whose source
        relation is not in the sub-json are removed, with a warning.

    Raises:
        AttributeError: if a requested model is not declared.
    """
    from nnodely.basic.relation import MAIN_JSON
    json = copy.deepcopy(json)
    sub_json = copy.deepcopy(MAIN_JSON)
    declared_models = json['Models']
    single_model_json = type(declared_models) is str
    known_names = {declared_models} if single_model_json else set(declared_models.keys())

    if type(models) is str or len(models) == 1:
        # A one-element list is treated like a plain model name.
        model_name = models[0] if len(models) == 1 else models
        check(model_name in known_names, AttributeError, f"Model [{model_name}] not found!")
        if single_model_json:
            outputs = set(json['Outputs'].keys())
        else:
            outputs = set(declared_models[model_name]['Outputs'])
        sub_json['Models'] = model_name
    else:
        outputs = set()
        sub_json['Models'] = {}
        for model_name in models:
            check(model_name in known_names, AttributeError, f"Model [{model_name}] not found!")
            outputs.update(declared_models[model_name]['Outputs'])
            sub_json['Models'][model_name] = dict(declared_models[model_name])

    final_json = merge(sub_json, subjson_from_output(json, outputs))
    # Drop external (non-local) links whose source relation is not part of the extracted graph.
    for input_name, input_json in final_json['Inputs'].items():
        for link in ('connect', 'closedLoop'):
            if link in input_json and (input_json['local'] == 0 and input_json[link] not in final_json['Relations'].keys()):
                del input_json[link]
                del input_json['local']
                log.warning(f'The input {input_name} is "{link}" outside the model connection removed for subjson')
    return final_json
214

215
def stream_to_str(obj, type = 'Stream'):
    """Return a coloured text dump of a stream.

    The result is three lines: a ``=``-padded header with ``type``, the
    pretty-printed ``obj.json``, and a ``-``-padded footer with the stream's
    name and dimension, all coloured with ``GREEN``.
    """
    from nnodely.visualizer.emptyvisualizer import color, GREEN
    header = color(f" {type} ".center(80, '='), GREEN, True)
    body = color(pformat(obj.json), GREEN)
    footer = color(f" {obj.name} {obj.dim} ".center(80, '-'), GREEN, True)
    return f"{header}\n{body}\n{footer}"
225

226
def plot_structure(json, filename='nnodely_graph', library='matplotlib', view=True):
    """Draw the network described by ``json`` with the chosen plotting library.

    Args:
        json: model json to draw.
        filename: output file name forwarded to the plotting backend.
        library: ``'matplotlib'`` or ``'graphviz'``.
        view: forwarded to the backend to control whether the result is shown.

    Raises:
        ValueError: if ``library`` is not one of the supported backends.
    """
    if library not in ('matplotlib', 'graphviz'):
        raise ValueError("Invalid library specified. Use 'matplotlib' or 'graphviz'.")
    if library == 'graphviz':
        plot_graphviz_structure(json, filename, view=view)
    else:
        plot_matplotlib_structure(json, filename, view=view)
236

237
def _matplotlib_layout(json, dx, dy):
    """Compute the (x, y) corner of every box of the matplotlib diagram.

    Column 0 holds the inputs followed by the constants. Each following column
    holds the relations whose dependencies are all available from earlier
    columns. The last column holds the outputs. Columns grow downwards from
    y=0 in steps of ``dy``; columns are ``dx`` apart.

    Returns:
        ``(positions, x_limit, y_limit)``: a dict mapping node names to their
        corner, and the absolute horizontal/vertical extent used.
    """
    positions = {}
    x, y = 0, 0
    for name in [*json['Inputs'].keys(), *json['Constants'].keys()]:
        positions[name] = (x, y)
        y -= dy
    y_limit = abs(y)

    # Parameters get no box, but relations may list them as dependencies.
    # Treating them as available lets such relations be placed.
    available = set(json['Inputs'].keys()) | set(json['Constants'].keys()) | set(json.get('Parameters', {}).keys())
    pending = [name for name in json['Relations'] if name not in available]
    # Keep going until every relation has a position: the drawing code needs
    # one for each of them, not only for those feeding an output.
    while pending:
        x += dx
        y = 0
        column = [name for name in pending
                  if all(dep in available for dep in json['Relations'][name][1])]
        if not column:
            # The remaining dependencies can never be satisfied (unknown names):
            # put everything left in one column instead of looping forever.
            column = pending
        for name in column:
            positions[name] = (x, y)
            y -= dy
        y_limit = max(y_limit, abs(y))
        available.update(column)
        pending = [name for name in pending if name not in available]

    x += dx
    y = 0
    for name in json['Outputs'].keys():
        positions[name] = (x, y)
        y -= dy
    return positions, abs(x), max(y_limit, abs(y))


def plot_matplotlib_structure(json, filename='nnodely_graph', view=True):
    """Draw the network structure with matplotlib and save it as PNG.

    Inputs (light green) and constants (light gray) occupy the first column,
    relations (light blue) follow in dependency order, and outputs (orange)
    occupy the last column. Black arrows are dependencies and red dashed
    lines are closed loops. Green dashed arrows are ``connect`` links, drawn
    from the relation to the input it feeds.

    Args:
        json: model json to draw (reads ``Inputs``, ``Constants``,
            ``Relations``, ``Outputs`` and ``Info['SampleTime']``).
        filename: destination passed to ``plt.savefig`` (format PNG).
        view: if True call ``plt.show()``; otherwise close the figure after saving.
    """
    import matplotlib.pyplot as plt
    from matplotlib import patches
    from matplotlib.lines import Line2D

    dy, dx = 1.5, 2.5  # vertical / horizontal spacing between boxes
    layer_positions, x_limit, y_limit = _matplotlib_layout(json, dx, dy)

    fig, ax = plt.subplots(figsize=(x_limit, y_limit))

    # One box per node, coloured by category
    colors, labels = ['lightgreen', 'lightblue', 'orange', 'lightgray'], ['Inputs', 'Relations', 'Outputs', 'Constants']
    legend_info = [patches.Patch(facecolor=color, edgecolor='black', label=label) for color, label in zip(colors, labels)]
    for layer in (json['Inputs'].keys() | json['Outputs'].keys() | json['Relations'].keys() | json['Constants'].keys()):
        x1, y1 = layer_positions[layer]
        if layer in json['Inputs']:
            face = 'lightgreen'
            tag = f'{layer}\ndim: {json["Inputs"][layer]["dim"]}\nWindow: {json["Inputs"][layer]["ntot"]}'
        elif layer in json['Outputs']:
            face = 'orange'
            tag = layer
        elif layer in json['Constants']:
            face = 'lightgray'
            tag = f'{layer}\ndim: {json["Constants"][layer]["dim"]}'
        else:
            face = 'lightblue'
            tag = f'{json["Relations"][layer][0]}\n({layer})'
        ax.add_patch(patches.Rectangle((x1, y1), 2, 1, edgecolor='black', facecolor=face))
        ax.text(x1 + 1, y1 + 0.5, f"{tag}", ha='center', va='center', fontsize=8, fontweight='bold')

    # Dependency arrows: from the right edge of each dependency to the relation box
    for layer, (_, dependencies, *_) in json['Relations'].items():
        x1, y1 = layer_positions[layer]
        for dep in dependencies:
            if dep in layer_positions:  # e.g. parameters have no box
                x2, y2 = layer_positions[dep]
                ax.annotate("", xy=(x1, y1), xytext=(x2 + 2, y2 + 0.5), arrowprops=dict(arrowstyle="->", color='black', lw=1))
    for out_name, rel_name in json['Outputs'].items():
        x1, y1 = layer_positions[out_name]
        x2, y2 = layer_positions[rel_name]
        ax.annotate("", xy=(x1, y1 + 0.5), xytext=(x2 + 2, y2 + 0.5),
                    arrowprops=dict(arrowstyle="->", color='black', lw=1))

    for key, state in json['Inputs'].items():
        if 'closedLoop' in state:
            x1, y1 = layer_positions[key]
            x2, y2 = layer_positions[state['closedLoop']]
            # Route the loop below the diagram: down from the relation, left, up to the input.
            ax.add_patch(patches.FancyArrowPatch((x2+1, y2), (x2+1, -y_limit), arrowstyle='-', mutation_scale=15, color='red', linestyle='dashed'))
            ax.add_patch(patches.FancyArrowPatch((x2+1, -y_limit), (x1-1, -y_limit), arrowstyle='-', mutation_scale=15, color='red', linestyle='dashed'))
            ax.add_patch(patches.FancyArrowPatch((x1-1, -y_limit), (x1-1, y1+0.5), arrowstyle='-', mutation_scale=15, color='red', linestyle='dashed'))
            ax.add_patch(patches.FancyArrowPatch((x1-1, y1+0.5), (x1, y1+0.5), arrowstyle='->', mutation_scale=15, color='red', linestyle='dashed'))
        elif 'connect' in state:
            x1, y1 = layer_positions[key]
            x2, y2 = layer_positions[state['connect']]
            # The connected relation feeds the input, so the arrow points relation -> input.
            ax.add_patch(patches.FancyArrowPatch((x2, y2), (x1, y1), arrowstyle='->', mutation_scale=15, color='green', linestyle='dashed'))

    legend_info.extend([Line2D([0], [0], color='black', lw=2, label='Dependency'),
                        Line2D([0], [0], color='red', lw=2, linestyle='dashed', label='Closed Loop'),
                        Line2D([0], [0], color='green', lw=2, linestyle='dashed', label='Connect')])

    ax.set_xlim(-dx, x_limit+dx)
    ax.set_ylim(-y_limit, dy)
    ax.set_aspect('equal')
    ax.legend(handles=legend_info, loc='lower right')
    ax.axis('off')

    plt.title(f"Neural Network Diagram - Sampling [{json['Info']['SampleTime']}]", fontsize=12, fontweight='bold')
    plt.savefig(filename, format="png", bbox_inches='tight')
    if view:
        plt.show()
    else:
        # Release the figure; otherwise repeated calls keep accumulating open figures.
        plt.close(fig)
348

349
def plot_graphviz_structure(json, filename='nnodely_graph', view=True): # pragma: no cover
    """Draw the network structure with graphviz and render it as SVG.

    Nodes are coloured by category:
      * inputs: light green;
      * constants: light gray;
      * relations: light blue;
      * parameters/functions referenced by a relation: orange/dark orange;
      * outputs: light coral;
      * minimizers (if present): purple.

    ``connect`` and ``closedLoop`` links are drawn as blue and red edges into
    the input. If the ``dot`` executable is not found on PATH, a warning is
    logged and nothing is rendered.

    Args:
        json: model json to draw.
        filename: base file name passed to ``Digraph.render``.
        view: forwarded to ``Digraph.render`` to decide whether to open the result.
    """
    import shutil
    # Do not import ``graphviz.view`` here: it would shadow the ``view``
    # argument and make render() always open the viewer.
    from graphviz import Digraph

    if shutil.which('dot') is None:
        log.warning(
            "Graphviz does not appear to be installed on your system. "
            "Please install it from https://graphviz.org/download/"
        )
        return

    dot = Digraph(comment='Structured Neural Network')
    # Left-to-right layout and default node style
    dot.attr(rankdir='LR', size='21')
    dot.attr('node', shape='box', style='filled', color='lightgray', fontname='Helvetica')

    # Metadata/info box
    if 'Info' in json:
        info_text = '\n'.join([f"{k}: {v}" for k, v in json['Info'].items()])
        dot.node('INFO_BOX', label=f"Model Info\n{info_text}", shape='note', fillcolor='white', fontsize='10')

    # Input nodes and their connect / closed-loop links
    for inp, data in json['Inputs'].items():
        label = f"{inp}\nDim: {data['dim']}\nWindow: {data['sw']}"
        dot.node(inp, label=label, fillcolor='lightgreen')
        if 'connect' in data:
            dot.edge(data['connect'], inp, label='connect', color='blue', fontcolor='blue')
        if 'closedLoop' in data:
            dot.edge(data['closedLoop'], inp, label='closedLoop', color='red', fontcolor='red')

    # Constant nodes
    if 'Constants' in json:
        for const, data in json['Constants'].items():
            dot.node(const, label=f"{const}\nDim: {data['dim']}", fillcolor='lightgray')

    # Relation nodes; only the first two extra arguments are inspected for
    # parameter/function names.
    for name, rel in json['Relations'].items():
        op_type, parents = rel[0], rel[1]
        dot.node(name, label=f"{name}\nType: {op_type}", fillcolor='lightblue')
        for arg in rel[2:4]:
            if not isinstance(arg, str):
                continue
            if arg in json['Parameters']:
                param_dim = json['Parameters'][arg]['dim']
                dot.node(arg, label=f"{arg}\nDim: {param_dim}", shape='ellipse', fillcolor='orange')
                dot.edge(arg, name, label='Parameter', color='orange', fontcolor='orange')
            elif arg in json['Functions']:
                # Label the node with its own name (previously always used the first argument).
                dot.node(arg, label=f"{arg}", shape='ellipse', fillcolor='darkorange')
                dot.edge(arg, name, label='function', color='darkorange', fontcolor='darkorange')
        for parent in parents:
            dot.edge(parent, name)

    # Output nodes
    for out, rel in json['Outputs'].items():
        dot.node(out, fillcolor='lightcoral')
        dot.edge(rel, out)

    # Minimizer nodes, if present
    if 'Minimizers' in json:
        for name, rel in json['Minimizers'].items():
            dot.node(name, label=f"{name}\nLoss:{rel['loss']}", shape='ellipse', fillcolor='purple')
            dot.edge(rel['A'], name, label='Minimize', color='purple', fontcolor='purple')
            dot.edge(rel['B'], name, label='Minimize', color='purple', fontcolor='purple')

    dot.render(filename=filename, view=view, format='svg')
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc