• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

DHARPA-Project / kiara_plugin.network_analysis / 15680097829

16 Jun 2025 11:55AM UTC coverage: 60.865% (+0.05%) from 60.815%
15680097829

push

github

makkus
chore: fix linting & mypy issues

63 of 121 branches covered (52.07%)

Branch coverage included in aggregate %.

5 of 18 new or added lines in 3 files covered. (27.78%)

3 existing lines in 2 files now uncovered.

472 of 758 relevant lines covered (62.27%)

3.11 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

30.63
/src/kiara_plugin/network_analysis/utils.py
1
# -*- coding: utf-8 -*-
2
#  Copyright (c) 2022, Markus Binsteiner
3
#
4
#  Mozilla Public License, version 2.0 (see LICENSE or https://www.mozilla.org/en-US/MPL/2.0/)
5
from typing import (
5✔
6
    TYPE_CHECKING,
7
    Any,
8
    Dict,
9
    Hashable,
10
    Iterable,
11
    List,
12
    Tuple,
13
    Union,
14
)
15

16
from kiara.exceptions import KiaraException
5✔
17
from kiara_plugin.network_analysis.defaults import (
5✔
18
    CONNECTIONS_COLUMN_NAME,
19
    CONNECTIONS_MULTI_COLUMN_NAME,
20
    COUNT_DIRECTED_COLUMN_NAME,
21
    COUNT_IDX_DIRECTED_COLUMN_NAME,
22
    COUNT_IDX_UNDIRECTED_COLUMN_NAME,
23
    COUNT_UNDIRECTED_COLUMN_NAME,
24
    EDGE_ID_COLUMN_NAME,
25
    IN_DIRECTED_COLUMN_NAME,
26
    IN_DIRECTED_MULTI_COLUMN_NAME,
27
    LABEL_COLUMN_NAME,
28
    NODE_ID_COLUMN_NAME,
29
    OUT_DIRECTED_COLUMN_NAME,
30
    OUT_DIRECTED_MULTI_COLUMN_NAME,
31
    SOURCE_COLUMN_NAME,
32
    TARGET_COLUMN_NAME,
33
    UNWEIGHTED_DEGREE_CENTRALITY_COLUMN_NAME,
34
    UNWEIGHTED_DEGREE_CENTRALITY_MULTI_COLUMN_NAME,
35
)
36

37
if TYPE_CHECKING:
5✔
38
    import networkx as nx
×
39
    import polars as pl
×
40
    import pyarrow as pa
×
41
    from sqlalchemy import MetaData, Table  # noqa
×
42

43

44
def extract_networkx_nodes_as_table(
    graph: "nx.Graph",
    label_attr_name: Union[str, None, Iterable[str]] = None,
    ignore_attributes: Union[None, Iterable[str]] = None,
) -> Tuple["pa.Table", Dict[Hashable, int]]:
    """Extract the nodes of a networkx graph as a pyarrow table.

    Arguments:
        graph: the networkx graph
        label_attr_name: the name of the node attribute that should be used as label. If None, the node id is used. An iterable of names means: first non-falsy value wins.
        ignore_attributes: a list of node attributes that should be ignored and not added to the table

    Returns:
        a tuple with the table and a map containing the original node id as key and the newly created internal node id (int) as value

    Raises:
        KiaraException: if any node attribute name starts with '_' (reserved for internal use)
    """
    # adapted from networkx code
    # License: 3-clause BSD license
    # Copyright (C) 2004-2022, NetworkX Developers

    import pyarrow as pa

    # materialize the (possibly one-shot) iterables once, so the repeated
    # membership tests / iterations below can't exhaust a generator
    ignore: Union[frozenset, None] = (
        frozenset(ignore_attributes) if ignore_attributes else None
    )
    if label_attr_name is not None and not isinstance(label_attr_name, str):
        label_attr_name = tuple(label_attr_name)

    nodes: Dict[str, List[Any]] = {
        NODE_ID_COLUMN_NAME: [],
        LABEL_COLUMN_NAME: [],
    }
    nodes_map: Dict[Hashable, int] = {}

    for i, (node_id, node_data) in enumerate(graph.nodes(data=True)):
        nodes[NODE_ID_COLUMN_NAME].append(i)
        if label_attr_name is None:
            nodes[LABEL_COLUMN_NAME].append(str(node_id))
        elif isinstance(label_attr_name, str):
            label = node_data.get(label_attr_name, None)
            if label:
                nodes[LABEL_COLUMN_NAME].append(str(label))
            else:
                # fall back to the node id if the attribute is missing/falsy
                nodes[LABEL_COLUMN_NAME].append(str(node_id))
        else:
            # iterable of candidate attribute names: first non-falsy value wins
            label_final = None
            for label in label_attr_name:
                label_final = node_data.get(label, None)
                if label_final:
                    break
            if not label_final:
                label_final = node_id
            nodes[LABEL_COLUMN_NAME].append(str(label_final))

        nodes_map[node_id] = i

        for k, v in node_data.items():
            if ignore and k in ignore:
                continue

            if k.startswith("_"):
                raise KiaraException(
                    "Graph contains node column name starting with '_'. This is reserved for internal use, and not allowed."
                )

            # if this attribute first appears on node 'i', back-fill the
            # previous rows with None so the column stays aligned
            nodes.setdefault(k, [None] * i).append(v)

        # pad attribute columns for nodes that lack an attribute seen on
        # earlier nodes, otherwise pa.Table.from_pydict fails on ragged columns
        for column in nodes.values():
            if len(column) <= i:
                column.append(None)

    nodes_table = pa.Table.from_pydict(mapping=nodes)

    return nodes_table, nodes_map
109

110

111
def extract_networkx_edges_as_table(
    graph: "nx.Graph", node_id_map: Dict[Hashable, int]
) -> "pa.Table":
    """Extract the edges of this graph as a pyarrow table.

    The provided `node_id_map` might be modified if a node id is not yet in the map.

    Args:
        graph: The graph to extract edges from.
        node_id_map: A mapping from (original) node ids to (kiara-internal) (integer) node-ids.

    Raises:
        KiaraException: if any edge attribute name starts with '_' (reserved for internal use)
    """

    # adapted from networkx code
    # License: 3-clause BSD license
    # Copyright (C) 2004-2022, NetworkX Developers

    import pyarrow as pa

    if node_id_map is None:
        node_id_map = {}

    # 'default=-1' makes the first auto-assigned id 0 when the map is empty;
    # a bare max() would raise ValueError on an empty map
    max_node_id = max(node_id_map.values(), default=-1)

    # source/target hold ints, attribute columns hold arbitrary values
    edge_columns: Dict[str, List[Any]] = {
        SOURCE_COLUMN_NAME: [],
        TARGET_COLUMN_NAME: [],
    }

    num_edges = 0
    for source, target, edge_data in graph.edges(data=True):
        if source not in node_id_map:
            max_node_id += 1
            node_id_map[source] = max_node_id
        if target not in node_id_map:
            max_node_id += 1
            node_id_map[target] = max_node_id

        edge_columns[SOURCE_COLUMN_NAME].append(node_id_map[source])
        edge_columns[TARGET_COLUMN_NAME].append(node_id_map[target])
        num_edges += 1

        for k, v in edge_data.items():
            if k.startswith("_"):
                raise KiaraException(
                    "Graph contains edge column name starting with '_'. This is reserved for internal use, and not allowed."
                )

            # if this attribute first appears on this edge, back-fill the
            # previous rows with None so the column stays aligned
            edge_columns.setdefault(k, [None] * (num_edges - 1)).append(v)

        # pad attribute columns for edges that lack an attribute seen on
        # earlier edges, otherwise pa.Table.from_pydict fails on ragged columns
        for column in edge_columns.values():
            if len(column) < num_edges:
                column.append(None)

    edges_table = pa.Table.from_pydict(mapping=edge_columns)

    return edges_table
163

164

165
def augment_nodes_table_with_connection_counts(
    nodes_table: Union["pa.Table", "pl.DataFrame"],
    edges_table: Union["pa.Table", "pl.DataFrame"],
) -> "pa.Table":
    """Augment a nodes table with pre-computed connection-count / degree columns.

    For every node, duckdb is used to compute from `edges_table`:
      - in/out edge counts, both de-duplicating parallel edges (the
        `COUNT_IDX_DIRECTED == 1` joins count each (source, target) pair once)
        and counting every parallel edge ('multi' variants)
      - total connections (in + out) for both variants
      - unweighted degree-centrality values derived from those counts

    Args:
        nodes_table: the nodes table (pyarrow or polars); referenced by name from the SQL below
        edges_table: the edges table (pyarrow or polars); referenced by name from the SQL below

    Returns:
        the augmented nodes table as a pyarrow table, ordered by node id
    """
    import duckdb

    # pyarrow exposes 'column_names', polars exposes 'columns'
    try:
        nodes_column_names = nodes_table.column_names  # type: ignore
    except Exception:
        nodes_column_names = nodes_table.columns  # type: ignore

    node_attr_columns = [x for x in nodes_column_names if not x.startswith("_")]
    if node_attr_columns:
        other_columns = ", " + ", ".join(node_attr_columns)
    else:
        other_columns = ""

    # pre-computed row count, so we can avoid 'COUNT(*)' calls in the queries below
    nodes_table_rows = len(nodes_table)

    # each left-joined sub-query yields at most one row per node (e1/e3 are
    # restricted to count-idx == 1, e2/e4 group by the node column alone), so
    # the result has exactly one row per input node
    query = f"""
    SELECT
         {NODE_ID_COLUMN_NAME},
         {LABEL_COLUMN_NAME},
         COALESCE(e1.{IN_DIRECTED_COLUMN_NAME}, 0) + COALESCE(e3.{OUT_DIRECTED_COLUMN_NAME}, 0) as {CONNECTIONS_COLUMN_NAME},
         COALESCE(e2.{IN_DIRECTED_MULTI_COLUMN_NAME}, 0) + COALESCE(e4.{OUT_DIRECTED_MULTI_COLUMN_NAME}, 0) as {CONNECTIONS_MULTI_COLUMN_NAME},
         COALESCE(e1.{IN_DIRECTED_COLUMN_NAME}, 0) as {IN_DIRECTED_COLUMN_NAME},
         COALESCE(e2.{IN_DIRECTED_MULTI_COLUMN_NAME}, 0) as {IN_DIRECTED_MULTI_COLUMN_NAME},
         COALESCE(e3.{OUT_DIRECTED_COLUMN_NAME}, 0) as {OUT_DIRECTED_COLUMN_NAME},
         COALESCE(e4.{OUT_DIRECTED_MULTI_COLUMN_NAME}, 0) as {OUT_DIRECTED_MULTI_COLUMN_NAME}
         {other_columns}
         FROM nodes_table n
         left join
           (SELECT {TARGET_COLUMN_NAME}, {COUNT_IDX_DIRECTED_COLUMN_NAME}, COUNT(*) as {IN_DIRECTED_COLUMN_NAME} from edges_table GROUP BY {TARGET_COLUMN_NAME}, {COUNT_IDX_DIRECTED_COLUMN_NAME}) e1
           on n.{NODE_ID_COLUMN_NAME} = e1.{TARGET_COLUMN_NAME} and e1.{COUNT_IDX_DIRECTED_COLUMN_NAME} = 1
         left join
           (SELECT {TARGET_COLUMN_NAME}, COUNT(*) as {IN_DIRECTED_MULTI_COLUMN_NAME} from edges_table GROUP BY {TARGET_COLUMN_NAME}) e2
           on n.{NODE_ID_COLUMN_NAME} = e2.{TARGET_COLUMN_NAME}
         left join
           (SELECT {SOURCE_COLUMN_NAME}, {COUNT_IDX_DIRECTED_COLUMN_NAME}, COUNT(*) as {OUT_DIRECTED_COLUMN_NAME} from edges_table GROUP BY {SOURCE_COLUMN_NAME}, {COUNT_IDX_DIRECTED_COLUMN_NAME}) e3
           on n.{NODE_ID_COLUMN_NAME} = e3.{SOURCE_COLUMN_NAME} and e3.{COUNT_IDX_DIRECTED_COLUMN_NAME} = 1
         left join
           (SELECT {SOURCE_COLUMN_NAME}, COUNT(*) as {OUT_DIRECTED_MULTI_COLUMN_NAME} from edges_table GROUP BY {SOURCE_COLUMN_NAME}) e4
           on n.{NODE_ID_COLUMN_NAME} = e4.{SOURCE_COLUMN_NAME}
        ORDER BY {NODE_ID_COLUMN_NAME}
    """

    # 'noqa': the variable is referenced by name from the SQL string below
    # (duckdb replacement scan), not from Python code
    nodes_table_result = duckdb.sql(query)  # noqa

    # NOTE(review): centrality divides by the node count n, not n - 1 as in
    # the textbook definition of degree centrality -- confirm this is intended
    centrality_query = f"""
    SELECT
         {NODE_ID_COLUMN_NAME},
         {LABEL_COLUMN_NAME},
         {CONNECTIONS_COLUMN_NAME},
         {CONNECTIONS_COLUMN_NAME} / {nodes_table_rows} AS {UNWEIGHTED_DEGREE_CENTRALITY_COLUMN_NAME},
         {CONNECTIONS_MULTI_COLUMN_NAME},
         {CONNECTIONS_MULTI_COLUMN_NAME} / {nodes_table_rows} AS {UNWEIGHTED_DEGREE_CENTRALITY_MULTI_COLUMN_NAME},
         {IN_DIRECTED_COLUMN_NAME},
         {IN_DIRECTED_MULTI_COLUMN_NAME},
         {OUT_DIRECTED_COLUMN_NAME},
         {OUT_DIRECTED_MULTI_COLUMN_NAME}
         {other_columns}
    FROM nodes_table_result
    """

    result = duckdb.sql(centrality_query)

    nodes_table_augmented = result.arrow()
    return nodes_table_augmented
235

236

237
def augment_edges_table_with_id_and_weights(
    edges_table: Union["pa.Table", "pl.DataFrame"],
) -> "pa.Table":
    """Augment the edges table with additional pre-computed columns for directed and undirected weights.

    Adds, via a single duckdb window query:
      - a 0-based edge id, following the row order of the input table
      - per directed (source, target) pair: the number of parallel edges plus
        a 1-based index of each edge within that group
      - the same two columns for the undirected interpretation, where
        (source, target) and (target, source) are treated as the same pair
        (normalized with LEAST/GREATEST)

    Args:
        edges_table: the edges table (pyarrow or polars); referenced by name from the SQL below

    Returns:
        the augmented edges table as a pyarrow table
    """

    import duckdb

    # pyarrow exposes 'column_names', polars exposes 'columns'
    try:
        column_names = edges_table.column_names  # type: ignore
    except Exception:
        column_names = edges_table.columns  # type: ignore

    edge_attr_columns = [x for x in column_names if not x.startswith("_")]
    if edge_attr_columns:
        other_columns = ", " + ", ".join(edge_attr_columns)
    else:
        other_columns = ""

    # NOTE(review): the window functions below have no ORDER BY, so the
    # '*_idx' values depend on duckdb's scan order within each partition --
    # confirm downstream consumers only rely on them being unique per group
    query = f"""
    SELECT
      ROW_NUMBER() OVER () -1 as {EDGE_ID_COLUMN_NAME},
      {SOURCE_COLUMN_NAME},
      {TARGET_COLUMN_NAME},
      COUNT(*) OVER (PARTITION BY {SOURCE_COLUMN_NAME}, {TARGET_COLUMN_NAME}) as {COUNT_DIRECTED_COLUMN_NAME},
      ROW_NUMBER() OVER (PARTITION BY {SOURCE_COLUMN_NAME}, {TARGET_COLUMN_NAME}) as {COUNT_IDX_DIRECTED_COLUMN_NAME},
      COUNT(*) OVER (PARTITION BY LEAST({SOURCE_COLUMN_NAME}, {TARGET_COLUMN_NAME}), GREATEST({SOURCE_COLUMN_NAME}, {TARGET_COLUMN_NAME})) as {COUNT_UNDIRECTED_COLUMN_NAME},
      ROW_NUMBER() OVER (PARTITION BY LEAST({SOURCE_COLUMN_NAME}, {TARGET_COLUMN_NAME}), GREATEST({SOURCE_COLUMN_NAME}, {TARGET_COLUMN_NAME})) as {COUNT_IDX_UNDIRECTED_COLUMN_NAME}
      {other_columns}
    FROM edges_table"""

    result = duckdb.sql(query)
    edges_table_augmented = result.arrow()

    return edges_table_augmented
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc