
DHARPA-Project / kiara_plugin.network_analysis / build 16113187676

07 Jul 2025 09:28AM UTC coverage: 54.789% (+0.1%) from 54.641%
Triggered by a push (github), author: makkus, commit: "build: add marimo depenendcy"

84 of 163 branches covered (51.53%); branch coverage is included in the aggregate percentage.
591 of 1069 relevant lines covered (55.29%), 2.76 hits per line.

Source file: /src/kiara_plugin/network_analysis/utils/__init__.py (56.54% covered)
# -*- coding: utf-8 -*-
#  Copyright (c) 2022, Markus Binsteiner
#
#  Mozilla Public License, version 2.0 (see LICENSE or https://www.mozilla.org/en-US/MPL/2.0/)
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    Hashable,
    Iterable,
    List,
    Tuple,
    Union,
)

import duckdb

from kiara.exceptions import KiaraException
from kiara_plugin.network_analysis.defaults import (
    COMPONENT_ID_COLUMN_NAME,
    CONNECTIONS_COLUMN_NAME,
    CONNECTIONS_MULTI_COLUMN_NAME,
    COUNT_DIRECTED_COLUMN_NAME,
    COUNT_IDX_DIRECTED_COLUMN_NAME,
    COUNT_IDX_UNDIRECTED_COLUMN_NAME,
    COUNT_UNDIRECTED_COLUMN_NAME,
    EDGE_ID_COLUMN_NAME,
    IN_DIRECTED_COLUMN_NAME,
    IN_DIRECTED_MULTI_COLUMN_NAME,
    LABEL_ALIAS_NAMES,
    LABEL_COLUMN_NAME,
    NODE_ID_ALIAS_NAMES,
    NODE_ID_COLUMN_NAME,
    OUT_DIRECTED_COLUMN_NAME,
    OUT_DIRECTED_MULTI_COLUMN_NAME,
    SOURCE_COLUMN_ALIAS_NAMES,
    SOURCE_COLUMN_NAME,
    TARGET_COLUMN_ALIAS_NAMES,
    TARGET_COLUMN_NAME,
    UNWEIGHTED_DEGREE_CENTRALITY_COLUMN_NAME,
    UNWEIGHTED_DEGREE_CENTRALITY_MULTI_COLUMN_NAME,
)

if TYPE_CHECKING:
    import networkx as nx
    import polars as pl
    import pyarrow as pa
    from sqlalchemy import MetaData, Table  # noqa

    from kiara.models.values.value import Value
    from kiara_plugin.network_analysis.models import NetworkData
    from kiara_plugin.tabular.models import KiaraTable

def extract_networkx_nodes_as_table(
    graph: "nx.Graph",
    label_attr_name: Union[str, None, Iterable[str]] = None,
    ignore_attributes: Union[None, Iterable[str]] = None,
) -> Tuple["pa.Table", Dict[Hashable, int]]:
    """Extract the nodes of a networkx graph as a pyarrow table.

    Arguments:
        graph: the networkx graph
        label_attr_name: the name of the node attribute that should be used as label. If None, the node id is used.
        ignore_attributes: a list of node attributes that should be ignored and not added to the table

    Returns:
        a tuple with the table and a map containing the original node id as key and the newly created internal node id (int) as value
    """
    # adapted from networkx code
    # License: 3-clause BSD license
    # Copyright (C) 2004-2022, NetworkX Developers

    import pyarrow as pa

    # nan = float("nan")

    nodes: Dict[str, List[Any]] = {
        NODE_ID_COLUMN_NAME: [],
        LABEL_COLUMN_NAME: [],
    }
    nodes_map = {}

    for i, (node_id, node_data) in enumerate(graph.nodes(data=True)):
        nodes[NODE_ID_COLUMN_NAME].append(i)
        if label_attr_name is None:
            nodes[LABEL_COLUMN_NAME].append(str(node_id))
        elif isinstance(label_attr_name, str):
            label = node_data.get(label_attr_name, None)
            if label:
                nodes[LABEL_COLUMN_NAME].append(str(label))
            else:
                nodes[LABEL_COLUMN_NAME].append(str(node_id))
        else:
            label_final = None
            for label in label_attr_name:
                label_final = node_data.get(label, None)
                if label_final:
                    break
            if not label_final:
                label_final = node_id
            nodes[LABEL_COLUMN_NAME].append(str(label_final))

        nodes_map[node_id] = i
        for k in node_data.keys():
            if ignore_attributes and k in ignore_attributes:
                continue

            if k.startswith("_"):
                raise KiaraException(
                    "Graph contains node column name starting with '_'. This is reserved for internal use, and not allowed."
                )

            v = node_data.get(k, None)
            nodes.setdefault(k, []).append(v)

    nodes_table = pa.Table.from_pydict(mapping=nodes)

    return nodes_table, nodes_map

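# Illustrative usage sketch (not part of the original module): extracting the nodes of
# a small networkx graph. The graph and its attribute names are made up for the example;
# the column names of the resulting table come from the constants imported above.
#
#     import networkx as nx
#
#     g = nx.Graph()
#     g.add_node("alice", label="Alice")
#     g.add_node("bob", label="Bob")
#     nodes_table, node_id_map = extract_networkx_nodes_as_table(g, label_attr_name="label")
#     # node_id_map maps the original node ids ("alice", "bob") to the newly assigned
#     # internal integer ids (0, 1).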

def extract_networkx_edges_as_table(
    graph: "nx.Graph", node_id_map: Dict[Hashable, int]
) -> "pa.Table":
    """Extract the edges of a networkx graph as a pyarrow table.

    The provided `node_id_map` might be modified if a node id is not yet in the map.

    Args:
        graph: The graph to extract edges from.
        node_id_map: A mapping from (original) node ids to (kiara-internal) (integer) node-ids.
    """

    # adapted from networkx code
    # License: 3-clause BSD license
    # Copyright (C) 2004-2022, NetworkX Developers

    import pyarrow as pa

    if node_id_map is None:
        node_id_map = {}

    # nan = float("nan")

    max_node_id = max(node_id_map.values())  # TODO: could we just use len(node_id_map)?
    edge_columns: Dict[str, List[int]] = {
        SOURCE_COLUMN_NAME: [],
        TARGET_COLUMN_NAME: [],
    }

    for source, target, edge_data in graph.edges(data=True):
        if source not in node_id_map.keys():
            max_node_id += 1
            node_id_map[source] = max_node_id
        if target not in node_id_map.keys():
            max_node_id += 1
            node_id_map[target] = max_node_id

        edge_columns[SOURCE_COLUMN_NAME].append(node_id_map[source])
        edge_columns[TARGET_COLUMN_NAME].append(node_id_map[target])

        for k in edge_data.keys():
            if k.startswith("_"):
                raise KiaraException(
                    "Graph contains edge column name starting with '_'. This is reserved for internal use, and not allowed."
                )

            v = edge_data.get(k, None)
            edge_columns.setdefault(k, []).append(v)  # type: ignore

    edges_table = pa.Table.from_pydict(mapping=edge_columns)

    return edges_table

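# Illustrative usage sketch (not part of the original module): extracting the edges after
# the nodes, so the node-id map created above can be passed in. The graph is made up for
# the example.
#
#     import networkx as nx
#
#     g = nx.Graph()
#     g.add_edge("alice", "bob", weight=2)
#     nodes_table, node_id_map = extract_networkx_nodes_as_table(g)
#     edges_table = extract_networkx_edges_as_table(g, node_id_map)
#     # edges_table has one row, with the internal integer ids of "alice" and "bob" in the
#     # source/target columns, plus a "weight" attribute column.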

def augment_nodes_table_with_connection_counts(
    nodes_table: Union["pa.Table", "pl.DataFrame"],
    edges_table: Union["pa.Table", "pl.DataFrame"],
) -> "pa.Table":
    """Augment the nodes table with pre-computed connection counts and (unweighted) degree centrality columns."""

    import duckdb

    try:
        nodes_column_names = nodes_table.column_names  # type: ignore
    except Exception:
        nodes_column_names = nodes_table.columns  # type: ignore

    node_attr_columns = [x for x in nodes_column_names if not x.startswith("_")]
    if node_attr_columns:
        other_columns = ", " + ", ".join(node_attr_columns)
    else:
        other_columns = ""

    # pre-compute the row count, so we can avoid 'COUNT(*)' calls in the following query
    nodes_table_rows = len(nodes_table)

    query = f"""
    SELECT
         {NODE_ID_COLUMN_NAME},
         {LABEL_COLUMN_NAME},
         COALESCE(e1.{IN_DIRECTED_COLUMN_NAME}, 0) + COALESCE(e3.{OUT_DIRECTED_COLUMN_NAME}, 0) as {CONNECTIONS_COLUMN_NAME},
         (COALESCE(e1._in_edges, 0) + COALESCE(e3._out_edges, 0)) / {nodes_table_rows} AS _degree_centrality,
         COALESCE(e2.{IN_DIRECTED_MULTI_COLUMN_NAME}, 0) + COALESCE(e4.{OUT_DIRECTED_MULTI_COLUMN_NAME}, 0) as {CONNECTIONS_MULTI_COLUMN_NAME},
         COALESCE(e1.{IN_DIRECTED_COLUMN_NAME}, 0) as {IN_DIRECTED_COLUMN_NAME},
         COALESCE(e2.{IN_DIRECTED_MULTI_COLUMN_NAME}, 0) as {IN_DIRECTED_MULTI_COLUMN_NAME},
         COALESCE(e3.{OUT_DIRECTED_COLUMN_NAME}, 0) as {OUT_DIRECTED_COLUMN_NAME},
         COALESCE(e4.{OUT_DIRECTED_MULTI_COLUMN_NAME}, 0) as {OUT_DIRECTED_MULTI_COLUMN_NAME}
         {other_columns}
         FROM nodes_table n
         left join
           (SELECT {TARGET_COLUMN_NAME}, {COUNT_IDX_DIRECTED_COLUMN_NAME}, COUNT(*) as {IN_DIRECTED_COLUMN_NAME} from edges_table GROUP BY {TARGET_COLUMN_NAME}, {COUNT_IDX_DIRECTED_COLUMN_NAME}) e1
           on n.{NODE_ID_COLUMN_NAME} = e1.{TARGET_COLUMN_NAME} and e1.{COUNT_IDX_DIRECTED_COLUMN_NAME} = 1
         left join
           (SELECT {TARGET_COLUMN_NAME}, COUNT(*) as {IN_DIRECTED_MULTI_COLUMN_NAME} from edges_table GROUP BY {TARGET_COLUMN_NAME}) e2
           on n.{NODE_ID_COLUMN_NAME} = e2.{TARGET_COLUMN_NAME}
         left join
           (SELECT {SOURCE_COLUMN_NAME}, {COUNT_IDX_DIRECTED_COLUMN_NAME}, COUNT(*) as {OUT_DIRECTED_COLUMN_NAME} from edges_table GROUP BY {SOURCE_COLUMN_NAME}, {COUNT_IDX_DIRECTED_COLUMN_NAME}) e3
           on n.{NODE_ID_COLUMN_NAME} = e3.{SOURCE_COLUMN_NAME} and e3.{COUNT_IDX_DIRECTED_COLUMN_NAME} = 1
         left join
           (SELECT {SOURCE_COLUMN_NAME}, COUNT(*) as {OUT_DIRECTED_MULTI_COLUMN_NAME} from edges_table GROUP BY {SOURCE_COLUMN_NAME}) e4
           on n.{NODE_ID_COLUMN_NAME} = e4.{SOURCE_COLUMN_NAME}
        ORDER BY {NODE_ID_COLUMN_NAME}
    """

    nodes_table_result = duckdb.sql(query)  # noqa

    centrality_query = f"""
    SELECT
         {NODE_ID_COLUMN_NAME},
         {LABEL_COLUMN_NAME},
         {CONNECTIONS_COLUMN_NAME},
         {CONNECTIONS_COLUMN_NAME} / (SELECT COUNT(*) FROM nodes_table_result) AS {UNWEIGHTED_DEGREE_CENTRALITY_COLUMN_NAME},
         {CONNECTIONS_MULTI_COLUMN_NAME},
         {CONNECTIONS_MULTI_COLUMN_NAME} / (SELECT COUNT(*) FROM nodes_table_result) AS {UNWEIGHTED_DEGREE_CENTRALITY_MULTI_COLUMN_NAME},
         {IN_DIRECTED_COLUMN_NAME},
         {IN_DIRECTED_MULTI_COLUMN_NAME},
         {OUT_DIRECTED_COLUMN_NAME},
         {OUT_DIRECTED_MULTI_COLUMN_NAME}
         {other_columns}
    FROM nodes_table_result
    """

    result = duckdb.sql(centrality_query)

    nodes_table_augmented = result.arrow()
    return nodes_table_augmented

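# Illustrative usage sketch (not part of the original module): the query above joins on the
# pre-computed count-index column that augment_edges_table_with_id_and_weights (below) adds,
# so the edges table appears to be expected in its augmented form. A typical call order,
# with made-up variable names, would be:
#
#     edges_augmented = augment_edges_table_with_id_and_weights(edges_table)
#     nodes_augmented = augment_nodes_table_with_connection_counts(
#         nodes_table=nodes_table, edges_table=edges_augmented
#     )
#     # nodes_augmented is a pyarrow.Table with additional connection-count and
#     # degree-centrality columns, ordered by the internal node id.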

def augment_edges_table_with_id_and_weights(
    edges_table: Union["pa.Table", "pl.DataFrame"],
) -> "pa.Table":
    """Augment the edges table with additional pre-computed columns for directed and undirected weights."""

    import duckdb

    try:
        column_names = edges_table.column_names  # type: ignore
    except Exception:
        column_names = edges_table.columns  # type: ignore

    edge_attr_columns = [x for x in column_names if not x.startswith("_")]
    if edge_attr_columns:
        other_columns = ", " + ", ".join(edge_attr_columns)
    else:
        other_columns = ""

    query = f"""
    SELECT
      ROW_NUMBER() OVER () -1 as {EDGE_ID_COLUMN_NAME},
      {SOURCE_COLUMN_NAME},
      {TARGET_COLUMN_NAME},
      COUNT(*) OVER (PARTITION BY {SOURCE_COLUMN_NAME}, {TARGET_COLUMN_NAME}) as {COUNT_DIRECTED_COLUMN_NAME},
      ROW_NUMBER(*) OVER (PARTITION BY {SOURCE_COLUMN_NAME}, {TARGET_COLUMN_NAME}) as {COUNT_IDX_DIRECTED_COLUMN_NAME},
      COUNT(*) OVER (PARTITION BY LEAST({SOURCE_COLUMN_NAME}, {TARGET_COLUMN_NAME}), GREATEST({SOURCE_COLUMN_NAME}, {TARGET_COLUMN_NAME})) as {COUNT_UNDIRECTED_COLUMN_NAME},
      ROW_NUMBER(*) OVER (PARTITION BY LEAST({SOURCE_COLUMN_NAME}, {TARGET_COLUMN_NAME}), GREATEST({SOURCE_COLUMN_NAME}, {TARGET_COLUMN_NAME})) as {COUNT_IDX_UNDIRECTED_COLUMN_NAME}
      {other_columns}
    FROM edges_table"""

    result = duckdb.sql(query)
    edges_table_augmented = result.arrow()

    return edges_table_augmented

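# Illustrative sketch (not part of the original module) of what the window functions above
# compute, using made-up integer node ids. For an edges table with the rows
# (0 -> 1), (0 -> 1), (1 -> 0):
#
#     import pyarrow as pa
#
#     edges = pa.table({SOURCE_COLUMN_NAME: [0, 0, 1], TARGET_COLUMN_NAME: [1, 1, 0]})
#     augmented = augment_edges_table_with_id_and_weights(edges)
#
# The directed count is 2 for the two parallel (0 -> 1) edges and 1 for (1 -> 0); the
# undirected count is 3 for all rows, since all of them connect the pair {0, 1}; the two
# index columns number the edges within each group starting at 1, which is what the
# "= 1" filters in augment_nodes_table_with_connection_counts rely on.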

def augment_tables_with_component_ids(
    nodes_table: "pa.Table", edges_table: "pa.Table"
) -> Tuple["pa.Table", "pa.Table"]:
    """Augment the nodes and edges table with a component id column.

    The component id is a unique id for each connected component in the graph. The id of the component is
    assigned in a deterministic way, based on the size of the components. The largest component gets id 0, the
    second largest 1, and so on.

    Should two components have the same number of nodes, the component id assignment is no longer deterministic.
    """

    import polars as pl
    import pyarrow as pa
    import rustworkx as rx

    # create rustworkx graph
    graph = rx.PyGraph(multigraph=False)

    nodes_df = pl.from_arrow(nodes_table)
    for row in nodes_df.select(NODE_ID_COLUMN_NAME).rows(named=True):
        node_id = row[NODE_ID_COLUMN_NAME]
        graph_node_id = graph.add_node(node_id)
        assert node_id == graph_node_id

    edges_df = pl.from_arrow(edges_table)
    for row in edges_df.select(SOURCE_COLUMN_NAME, TARGET_COLUMN_NAME).rows(named=True):
        if row[SOURCE_COLUMN_NAME] == row[TARGET_COLUMN_NAME]:
            continue

        source_id = row[SOURCE_COLUMN_NAME]
        target_id = row[TARGET_COLUMN_NAME]

        graph.add_edge(source_id, target_id, None)

    components = rx.connected_components(graph)

    # determine where to insert the new column: after the existing internal ('_') columns
    nodes_idx = 0
    for col_name in nodes_table.column_names:
        if col_name.startswith("_"):
            nodes_idx += 1
            continue

    if len(components) == 1:
        components_column_nodes = pa.array([0] * len(nodes_table), type=pa.int64())
        nodes = nodes_table.add_column(
            nodes_idx, COMPONENT_ID_COLUMN_NAME, components_column_nodes
        )
        components_column_edges = pa.array([0] * len(edges_table), type=pa.int64())

        edges_idx = 0
        for col_name in edges_table.column_names:
            if col_name.startswith("_"):
                edges_idx += 1
                continue
        edges = edges_table.add_column(
            edges_idx, COMPONENT_ID_COLUMN_NAME, components_column_edges
        )
        return nodes, edges

    node_components = {}
    for idx, component in enumerate(sorted(components, key=len, reverse=True)):
        for node_id in component:
            node_components[node_id] = idx

    if len(node_components) != graph.num_nodes():
        raise KiaraException(
            "Number of nodes in component map does not match number of nodes in network data. This is most likely a bug."
        )

    components_column_nodes = pa.array(
        (node_components[node_id] for node_id in sorted(node_components.keys())),
        type=pa.int64(),
    )
    nodes = nodes_table.add_column(
        nodes_idx, COMPONENT_ID_COLUMN_NAME, components_column_nodes
    )

    try:
        column_names = edges_table.column_names  # type: ignore
    except Exception:
        column_names = edges_table.columns  # type: ignore

    computed_attr_columns = [x for x in column_names if x.startswith("_")]
    computed_columns = ", ".join(computed_attr_columns)
    edge_attr_columns = [x for x in column_names if not x.startswith("_")]
    if edge_attr_columns:
        other_columns = ", " + ", ".join(edge_attr_columns)
    else:
        other_columns = ""

    # a query that looks up each edge's SOURCE_COLUMN_NAME value in the NODE_ID_COLUMN_NAME
    # column of the (already augmented) nodes table, and returns that node's component id
    query = f"""
    SELECT
        {computed_columns},
        n.{COMPONENT_ID_COLUMN_NAME} as {COMPONENT_ID_COLUMN_NAME}
        {other_columns}
    FROM edges_table e
    JOIN nodes n ON e.{SOURCE_COLUMN_NAME} = n.{NODE_ID_COLUMN_NAME}
    """

    edges = duckdb.sql(query)

    return nodes, edges.arrow()

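# Illustrative sketch (not part of the original module), reusing the augmented tables from
# the earlier sketches: if the internal node ids {0, 1, 2} form one connected component and
# {3, 4} another, the larger component gets id 0 and the smaller one id 1, so the added
# column reads 0, 0, 0, 1, 1 when the nodes are ordered by their internal node id.
#
#     nodes_with_components, edges_with_components = augment_tables_with_component_ids(
#         nodes_table=nodes_augmented, edges_table=edges_augmented
#     )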

def extract_network_data(network_data: Union["Value", "NetworkData"]) -> "NetworkData":
    from kiara.models.values.value import Value

    if isinstance(network_data, Value):
        assert network_data.data_type_name == "network_data"
        network_data = network_data.data
    return network_data

def guess_column_name(
    table: Union["pa.Table", "KiaraTable", "Value"], suggestions: List[str]
) -> Union[str, None]:
    """Return the first suggestion that matches one of the table's column names (exact match first, then case-insensitive), or None if there is no match."""

    column_names: Union[List[str], None] = None

    if hasattr(table, "column_names"):
        column_names = table.column_names
    else:
        from kiara.models.values.value import Value

        if isinstance(table, Value):
            table_instance = table.data
            if hasattr(table_instance, "column_names"):
                column_names = table_instance.column_names

    if not column_names:
        return None

    for suggestion in suggestions:
        if suggestion in column_names:
            return suggestion

    for suggestion in suggestions:
        for column_name in column_names:
            if suggestion.lower() == column_name.lower():
                return column_name

    return None

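# Illustrative usage sketch (not part of the original module), with made-up column names:
#
#     import pyarrow as pa
#
#     table = pa.table({"Source": ["a"], "Target": ["b"]})
#     guess_column_name(table, suggestions=["source", "from"])  # returns "Source" (case-insensitive match)
#     guess_column_name(table, suggestions=["weight"])          # returns None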

def guess_node_id_column_name(
    nodes_table: Union["pa.Table", "KiaraTable", "Value"],
    suggestions: Union[None, List[str]] = None,
) -> Union[str, None]:
    if suggestions is None:
        suggestions = NODE_ID_ALIAS_NAMES
    return guess_column_name(table=nodes_table, suggestions=suggestions)


def guess_node_label_column_name(
    nodes_table: Union["pa.Table", "KiaraTable", "Value"],
    suggestions: Union[None, List[str]] = None,
) -> Union[str, None]:
    if suggestions is None:
        suggestions = LABEL_ALIAS_NAMES
    return guess_column_name(table=nodes_table, suggestions=suggestions)


def guess_source_column_name(
    edges_table: Union["pa.Table", "KiaraTable", "Value"],
    suggestions: Union[None, List[str]] = None,
) -> Union[str, None]:
    if suggestions is None:
        suggestions = SOURCE_COLUMN_ALIAS_NAMES
    return guess_column_name(table=edges_table, suggestions=suggestions)


def guess_target_column_name(
    edges_table: Union["pa.Table", "KiaraTable", "Value"],
    suggestions: Union[None, List[str]] = None,
) -> Union[str, None]:
    if suggestions is None:
        suggestions = TARGET_COLUMN_ALIAS_NAMES
    return guess_column_name(table=edges_table, suggestions=suggestions)