• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

Qiskit / rustworkx / 12817162714

16 Jan 2025 08:34PM UTC coverage: 95.811% (-0.02%) from 95.83%
12817162714

push

github

web-flow
Bump PyO3 and rust-numpy to 0.23 (#1364)

* Bump PyO3 and rust-numpy to 0.23

This commit bumps the version of pyo3 and rust-numpy used by qiskit to
the latest release 0.23. The largest change by volume of code is the
deprecation of all the *_bound() methods. These are just warnings but
they would be fatal to our CI so it needs to be updated. The larger
functional change that required updating the code is the change in the
traits around converting to Python objects.

As a side effect of this change it lets us unify the hashbrown versions
installed because we can update indexmap.

* Remove unused features section from Cargo.toml

274 of 282 new or added lines in 17 files covered. (97.16%)

2 existing lines in 1 file now uncovered.

18343 of 19145 relevant lines covered (95.81%)

1471653.02 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.29
/src/digraph.rs
1
// Licensed under the Apache License, Version 2.0 (the "License"); you may
2
// not use this file except in compliance with the License. You may obtain
3
// a copy of the License at
4
//
5
//     http://www.apache.org/licenses/LICENSE-2.0
6
//
7
// Unless required by applicable law or agreed to in writing, software
8
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
9
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
10
// License for the specific language governing permissions and limitations
11
// under the License.
12

13
#![allow(clippy::borrow_as_ptr, clippy::redundant_closure)]
14

15
use std::cmp;
16
use std::cmp::Ordering;
17
use std::collections::BTreeMap;
18

19
use std::fs::File;
20
use std::io::prelude::*;
21
use std::io::{BufReader, BufWriter};
22
use std::str;
23

24
use hashbrown::{HashMap, HashSet};
25

26
use rustworkx_core::dictmap::*;
27
use rustworkx_core::graph_ext::*;
28

29
use smallvec::SmallVec;
30

31
use pyo3::exceptions::PyIndexError;
32
use pyo3::gc::PyVisit;
33
use pyo3::prelude::*;
34
use pyo3::types::{IntoPyDict, PyBool, PyDict, PyList, PyString, PyTuple, PyType};
35
use pyo3::IntoPyObjectExt;
36
use pyo3::PyTraverseError;
37
use pyo3::Python;
38

39
use ndarray::prelude::*;
40
use num_traits::Zero;
41
use numpy::Complex64;
42
use numpy::PyReadonlyArray2;
43

44
use petgraph::algo;
45
use petgraph::graph::{EdgeIndex, NodeIndex};
46
use petgraph::prelude::*;
47

48
use crate::RxPyResult;
49
use petgraph::visit::{
50
    EdgeIndexable, GraphBase, IntoEdgeReferences, IntoNodeReferences, NodeCount, NodeFiltered,
51
    NodeIndexable, Visitable,
52
};
53

54
use super::dot_utils::build_dot;
55
use super::iterators::{
56
    EdgeIndexMap, EdgeIndices, EdgeList, NodeIndices, NodeMap, WeightedEdgeList,
57
};
58
use super::{
59
    find_node_by_weight, generic_class_getitem, weight_callable, DAGHasCycle, DAGWouldCycle, IsNan,
60
    NoEdgeBetweenNodes, NoSuitableNeighbors, NodesRemoved, StablePyGraph,
61
};
62

63
use super::dag_algo::is_directed_acyclic_graph;
64

65
/// A class for creating directed graphs
66
///
67
/// The ``PyDiGraph`` class is used to create a directed graph. It can be a
68
/// multigraph (have multiple edges between nodes). Each node and edge
69
/// (although rarely used for edges) is indexed by an integer id. These ids
70
/// are stable for the lifetime of the graph object and on node or edge
71
/// deletions you can have holes in the list of indices for the graph.
72
/// Node indices will be reused on additions after removal. For example:
73
///
74
/// .. jupyter-execute::
75
///
76
///        import rustworkx as rx
77
///
78
///        graph = rx.PyDiGraph()
79
///        graph.add_nodes_from(list(range(5)))
80
///        graph.add_nodes_from(list(range(2)))
81
///        graph.remove_node(2)
82
///        print("After deletion:", graph.node_indices())
83
///        res_manual = graph.add_parent(6, None, None)
84
///        print("After adding a new node:", graph.node_indices())
85
///
86
/// Additionally, each node and edge contains an arbitrary Python object as a
87
/// weight/data payload. You can use the index for access to the data payload
88
/// as in the following example:
89
///
90
/// .. jupyter-execute::
91
///
92
///     import rustworkx as rx
93
///
94
///     graph = rx.PyDiGraph()
95
///     data_payload = "An arbitrary Python object"
96
///     node_index = graph.add_node(data_payload)
97
///     print("Node Index: %s" % node_index)
98
///     print(graph[node_index])
99
///
100
/// The PyDiGraph implements the Python mapping protocol for nodes so in
101
/// addition to access you can also update the data payload with:
102
///
103
/// .. jupyter-execute::
104
///
105
///     import rustworkx as rx
106
///
107
///     graph = rx.PyDiGraph()
108
///     data_payload = "An arbitrary Python object"
109
///     node_index = graph.add_node(data_payload)
110
///     graph[node_index] = "New Payload"
111
///     print("Node Index: %s" % node_index)
112
///     print(graph[node_index])
113
///
114
/// The PyDiGraph class has an option for real time cycle checking which can
115
/// be used to ensure any edges added to the graph does not introduce a cycle.
116
/// By default the real time cycle checking feature is disabled for performance,
117
/// however you can enable it by setting the ``check_cycle`` attribute to True.
118
/// For example::
119
///
120
///     import rustworkx as rx
121
///     dag = rx.PyDiGraph()
122
///     dag.check_cycle = True
123
///
124
/// or at object creation::
125
///
126
///     import rustworkx as rx
127
///     dag = rx.PyDiGraph(check_cycle=True)
128
///
129
/// With check_cycle set to true any calls to :meth:`PyDiGraph.add_edge` will
130
/// ensure that no cycles are added, ensuring that the PyDiGraph class truly
131
/// represents a directed acyclic graph. Do note that this cycle checking on
132
/// :meth:`~PyDiGraph.add_edge`, :meth:`~PyDiGraph.add_edges_from`,
133
/// :meth:`~PyDiGraph.add_edges_from_no_data`,
134
/// :meth:`~PyDiGraph.extend_from_edge_list`,  and
135
/// :meth:`~PyDiGraph.extend_from_weighted_edge_list` comes with a performance
136
/// penalty that grows as the graph does. If you're adding a node and edge at
137
/// the same time leveraging :meth:`PyDiGraph.add_child` or
138
/// :meth:`PyDiGraph.add_parent` will avoid this overhead.
139
///
140
/// By default a ``PyDiGraph`` is a multigraph (meaning there can be parallel
141
/// edges between nodes) however this can be disabled by setting the
142
/// ``multigraph`` kwarg to ``False`` when calling the ``PyDiGraph``
143
/// constructor. For example::
144
///
145
///     import rustworkx as rx
146
///     graph = rx.PyDiGraph(multigraph=False)
147
///
148
/// This can only be set at ``PyDiGraph`` initialization and not adjusted after
149
/// creation. When :attr:`~rustworkx.PyDiGraph.multigraph` is set to ``False``
150
/// if a method call is made that would add a parallel edge it will instead
151
/// update the existing edge's weight/data payload.
152
///
153
/// Each ``PyDiGraph`` object has an :attr:`~.PyDiGraph.attrs` attribute which is
154
/// used to contain additional attributes/metadata of the graph instance. By
155
/// default this is set to ``None`` but can optionally be specified by using the
156
/// ``attrs`` keyword argument when constructing a new graph::
157
///
158
///     graph = rustworkx.PyDiGraph(attrs=dict(source_path='/tmp/graph.csv'))
159
///
160
/// This attribute can be set to any Python object. Additionally, you can access
161
/// and modify this attribute after creating an object. For example::
162
///
163
///     source_path = graph.attrs
164
///     graph.attrs = {'new_path': '/tmp/new.csv', 'old_path': source_path}
165
///
166
/// The maximum number of nodes and edges allowed on a ``PyGraph`` object is
167
/// :math:`2^{32} - 1` (4,294,967,294) each. Attempting to add more nodes or
168
/// edges than this will result in an exception being raised.
169
///
170
/// :param bool check_cycle: When this is set to ``True`` the created
171
///     ``PyDiGraph`` has runtime cycle detection enabled.
172
/// :param bool multigraph: When this is set to ``False`` the created
173
///     ``PyDiGraph`` object will not be a multigraph. When ``False`` if a
174
///     method call is made that would add parallel edges the the weight/weight
175
///     from that method call will be used to update the existing edge in place.
176
/// :param attrs: An optional attributes payload to assign to the
177
///     :attr:`~.PyDiGraph.attrs` attribute. This can be any Python object. If
178
///     it is not specified :attr:`~.PyDiGraph.attrs` will be set to ``None``.
179
/// :param int node_count_hint: An optional hint that will allocate with enough capacity to store this
180
///     many nodes before needing to grow.  This does not prepopulate any nodes with data, it is
181
///     only a potential performance optimization if the complete size of the graph is known in
182
///     advance.
183
/// :param int edge_count_hint: An optional hint that will allocate enough capacity to store this
184
///     many edges before needing to grow.  This does not prepopulate any edges with data, it is
185
///     only a potential performance optimization if the complete size of the graph is known in
186
///     advance.
187
#[pyclass(mapping, module = "rustworkx", subclass)]
2,972✔
188
#[derive(Clone)]
189
pub struct PyDiGraph {
190
    pub graph: StablePyGraph<Directed>,
191
    pub cycle_state: algo::DfsSpace<NodeIndex, <StablePyGraph<Directed> as Visitable>::Map>,
192
    pub check_cycle: bool,
193
    pub node_removed: bool,
194
    pub multigraph: bool,
195
    #[pyo3(get, set)]
196
    pub attrs: PyObject,
20✔
197
}
198

199
impl GraphBase for PyDiGraph {
200
    type NodeId = NodeIndex;
201
    type EdgeId = EdgeIndex;
202
}
203

204
impl NodesRemoved for &PyDiGraph {
205
    fn nodes_removed(&self) -> bool {
×
206
        self.node_removed
×
207
    }
×
208
}
209

210
impl NodeCount for PyDiGraph {
211
    fn node_count(&self) -> usize {
3,520✔
212
        self.graph.node_count()
3,520✔
213
    }
3,520✔
214
}
215

216
// Rust side only PyDiGraph methods
217
impl PyDiGraph {
218
    fn add_edge_no_cycle_check(
2,026,554✔
219
        &mut self,
2,026,554✔
220
        p_index: NodeIndex,
2,026,554✔
221
        c_index: NodeIndex,
2,026,554✔
222
        edge: PyObject,
2,026,554✔
223
    ) -> usize {
2,026,554✔
224
        if !self.multigraph {
2,026,554✔
225
            let exists = self.graph.find_edge(p_index, c_index);
138✔
226
            if let Some(index) = exists {
138✔
227
                let edge_weight = self.graph.edge_weight_mut(index).unwrap();
16✔
228
                *edge_weight = edge;
16✔
229
                return index.index();
16✔
230
            }
122✔
231
        }
2,026,416✔
232
        let edge = self.graph.add_edge(p_index, c_index, edge);
2,026,538✔
233
        edge.index()
2,026,538✔
234
    }
2,026,554✔
235

236
    fn _add_edge(
2,026,566✔
237
        &mut self,
2,026,566✔
238
        p_index: NodeIndex,
2,026,566✔
239
        c_index: NodeIndex,
2,026,566✔
240
        edge: PyObject,
2,026,566✔
241
    ) -> PyResult<usize> {
2,026,566✔
242
        // Only check for cycles if instance attribute is set to true
2,026,566✔
243
        if self.check_cycle {
2,026,566✔
244
            // Only check for a cycle (by running has_path_connecting) if
245
            // the new edge could potentially add a cycle
246
            let cycle_check_required = is_cycle_check_required(self, p_index, c_index);
26✔
247
            let state = Some(&mut self.cycle_state);
26✔
248
            if cycle_check_required
26✔
249
                && algo::has_path_connecting(&self.graph, c_index, p_index, state)
14✔
250
            {
251
                return Err(DAGWouldCycle::new_err("Adding an edge would cycle"));
12✔
252
            }
14✔
253
        }
2,026,540✔
254
        Ok(self.add_edge_no_cycle_check(p_index, c_index, edge))
2,026,554✔
255
    }
2,026,566✔
256

257
    fn insert_between(
44✔
258
        &mut self,
44✔
259
        py: Python,
44✔
260
        node: usize,
44✔
261
        node_between: usize,
44✔
262
        direction: bool,
44✔
263
    ) -> PyResult<()> {
44✔
264
        let dir = if direction {
44✔
265
            petgraph::Direction::Outgoing
22✔
266
        } else {
267
            petgraph::Direction::Incoming
22✔
268
        };
269
        let index = NodeIndex::new(node);
44✔
270
        let node_between_index = NodeIndex::new(node_between);
44✔
271
        let edges: Vec<(NodeIndex, EdgeIndex, PyObject)> = self
44✔
272
            .graph
44✔
273
            .edges_directed(node_between_index, dir)
44✔
274
            .map(|edge| {
44✔
275
                if direction {
36✔
276
                    (edge.target(), edge.id(), edge.weight().clone_ref(py))
18✔
277
                } else {
278
                    (edge.source(), edge.id(), edge.weight().clone_ref(py))
18✔
279
                }
280
            })
44✔
281
            .collect::<Vec<(NodeIndex, EdgeIndex, PyObject)>>();
44✔
282
        for (other_index, edge_index, weight) in edges {
80✔
283
            if direction {
36✔
284
                self._add_edge(node_between_index, index, weight.clone_ref(py))?;
18✔
285
                self._add_edge(index, other_index, weight.clone_ref(py))?;
18✔
286
            } else {
287
                self._add_edge(other_index, index, weight.clone_ref(py))?;
18✔
288
                self._add_edge(index, node_between_index, weight.clone_ref(py))?;
18✔
289
            }
290
            self.graph.remove_edge(edge_index);
36✔
291
        }
292
        Ok(())
44✔
293
    }
44✔
294
}
295

296
#[pymethods]
817,852✔
297
impl PyDiGraph {
298
    #[new]
299
    #[pyo3(signature=(/, check_cycle=false, multigraph=true, attrs=None, *, node_count_hint=None, edge_count_hint=None))]
300
    fn new(
2,160✔
301
        py: Python,
2,160✔
302
        check_cycle: bool,
2,160✔
303
        multigraph: bool,
2,160✔
304
        attrs: Option<PyObject>,
2,160✔
305
        node_count_hint: Option<usize>,
2,160✔
306
        edge_count_hint: Option<usize>,
2,160✔
307
    ) -> Self {
2,160✔
308
        PyDiGraph {
2,160✔
309
            graph: StablePyGraph::<Directed>::with_capacity(
2,160✔
310
                node_count_hint.unwrap_or_default(),
2,160✔
311
                edge_count_hint.unwrap_or_default(),
2,160✔
312
            ),
2,160✔
313
            cycle_state: algo::DfsSpace::default(),
2,160✔
314
            check_cycle,
2,160✔
315
            node_removed: false,
2,160✔
316
            multigraph,
2,160✔
317
            attrs: attrs.unwrap_or_else(|| py.None()),
2,160✔
318
        }
2,160✔
319
    }
2,160✔
320

321
    fn __getnewargs_ex__<'py>(
34✔
322
        &self,
34✔
323
        py: Python<'py>,
34✔
324
    ) -> PyResult<(Bound<'py, PyTuple>, Bound<'py, PyDict>)> {
34✔
325
        Ok((
34✔
326
            (self.check_cycle, self.multigraph, self.attrs.clone_ref(py)).into_pyobject(py)?,
34✔
327
            [
34✔
328
                ("node_count_hint", self.graph.node_bound()),
34✔
329
                ("edge_count_hint", self.graph.edge_bound()),
34✔
330
            ]
34✔
331
            .into_py_dict(py)?,
34✔
332
        ))
333
    }
34✔
334

335
    fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
34✔
336
        let mut nodes: Vec<PyObject> = Vec::with_capacity(self.graph.node_bound());
34✔
337
        let mut edges: Vec<PyObject> = Vec::with_capacity(self.graph.edge_bound());
34✔
338

339
        // save nodes to a list along with its index
340
        for node_idx in self.graph.node_indices() {
124✔
341
            let node_data = self.graph.node_weight(node_idx).unwrap();
124✔
342
            nodes.push((node_idx.index(), node_data).into_py_any(py)?);
124✔
343
        }
344

345
        // edges are saved with none (deleted edges) instead of their index to save space
346
        for i in 0..self.graph.edge_bound() {
114✔
347
            let idx = EdgeIndex::new(i);
114✔
348
            let edge = match self.graph.edge_weight(idx) {
114✔
349
                Some(edge_w) => {
98✔
350
                    let endpoints = self.graph.edge_endpoints(idx).unwrap();
98✔
351
                    (endpoints.0.index(), endpoints.1.index(), edge_w).into_py_any(py)?
98✔
352
                }
353
                None => py.None(),
16✔
354
            };
355
            edges.push(edge);
114✔
356
        }
357
        let out_dict = PyDict::new(py);
34✔
358
        let nodes_lst = PyList::new(py, nodes)?;
34✔
359
        let edges_lst = PyList::new(py, edges)?;
34✔
360
        out_dict.set_item("nodes", nodes_lst)?;
34✔
361
        out_dict.set_item("edges", edges_lst)?;
34✔
362
        out_dict.set_item("nodes_removed", self.node_removed)?;
34✔
363
        Ok(out_dict.into())
34✔
364
    }
34✔
365

366
    fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
34✔
367
        let dict_state = state.downcast_bound::<PyDict>(py)?;
34✔
368
        let binding = dict_state.get_item("nodes")?.unwrap();
34✔
369
        let nodes_lst = binding.downcast::<PyList>()?;
34✔
370
        let binding = dict_state.get_item("edges")?.unwrap();
34✔
371
        let edges_lst = binding.downcast::<PyList>()?;
34✔
372
        self.graph = StablePyGraph::<Directed>::new();
34✔
373
        let dict_state = state.downcast_bound::<PyDict>(py)?;
34✔
374
        self.node_removed = dict_state
34✔
375
            .get_item("nodes_removed")?
34✔
376
            .unwrap()
34✔
377
            .downcast::<PyBool>()?
34✔
378
            .extract()?;
34✔
379

380
        // graph is empty, stop early
381
        if nodes_lst.is_empty() {
34✔
382
            return Ok(());
12✔
383
        }
22✔
384

22✔
385
        if !self.node_removed {
22✔
386
            for item in nodes_lst.iter() {
46✔
387
                let node_w = item
46✔
388
                    .downcast::<PyTuple>()
46✔
389
                    .unwrap()
46✔
390
                    .get_item(1)
46✔
391
                    .unwrap()
46✔
392
                    .extract()
46✔
393
                    .unwrap();
46✔
394
                self.graph.add_node(node_w);
46✔
395
            }
46✔
396
        } else if nodes_lst.len() == 1 {
12✔
397
            // graph has only one node, handle logic here to save one if in the loop later
398
            let binding = nodes_lst.get_item(0).unwrap();
×
399
            let item = binding.downcast::<PyTuple>().unwrap();
×
400
            let node_idx: usize = item.get_item(0).unwrap().extract().unwrap();
×
401
            let node_w = item.get_item(1).unwrap().extract().unwrap();
×
402

403
            for _i in 0..node_idx {
×
404
                self.graph.add_node(py.None());
×
405
            }
×
406
            self.graph.add_node(node_w);
×
407
            for i in 0..node_idx {
×
408
                self.graph.remove_node(NodeIndex::new(i));
×
409
            }
×
410
        } else {
411
            let binding = nodes_lst.get_item(nodes_lst.len() - 1).unwrap();
12✔
412
            let last_item = binding.downcast::<PyTuple>().unwrap();
12✔
413

12✔
414
            // list of temporary nodes that will be removed later to re-create holes
12✔
415
            let node_bound_1: usize = last_item.get_item(0).unwrap().extract().unwrap();
12✔
416
            let mut tmp_nodes: Vec<NodeIndex> =
12✔
417
                Vec::with_capacity(node_bound_1 + 1 - nodes_lst.len());
12✔
418

419
            for item in nodes_lst {
90✔
420
                let item = item.downcast::<PyTuple>().unwrap();
78✔
421
                let next_index: usize = item.get_item(0).unwrap().extract().unwrap();
78✔
422
                let weight: PyObject = item.get_item(1).unwrap().extract().unwrap();
78✔
423
                while next_index > self.graph.node_bound() {
98✔
424
                    // node does not exist
20✔
425
                    let tmp_node = self.graph.add_node(py.None());
20✔
426
                    tmp_nodes.push(tmp_node);
20✔
427
                }
20✔
428
                // add node to the graph, and update the next available node index
429
                self.graph.add_node(weight);
78✔
430
            }
431
            // Remove any temporary nodes we added
432
            for tmp_node in tmp_nodes {
32✔
433
                self.graph.remove_node(tmp_node);
20✔
434
            }
20✔
435
        }
436

437
        // to ensure O(1) on edge deletion, use a temporary node to store missing edges
438
        let tmp_node = self.graph.add_node(py.None());
22✔
439

440
        for item in edges_lst {
136✔
441
            if item.is_none() {
114✔
442
                // add a temporary edge that will be deleted later to re-create the hole
16✔
443
                self.graph.add_edge(tmp_node, tmp_node, py.None());
16✔
444
            } else {
98✔
445
                let triple = item.downcast::<PyTuple>().unwrap();
98✔
446
                let edge_p: usize = triple.get_item(0).unwrap().extract().unwrap();
98✔
447
                let edge_c: usize = triple.get_item(1).unwrap().extract().unwrap();
98✔
448
                let edge_w = triple.get_item(2).unwrap().extract().unwrap();
98✔
449
                self.graph
98✔
450
                    .add_edge(NodeIndex::new(edge_p), NodeIndex::new(edge_c), edge_w);
98✔
451
            }
98✔
452
        }
453

454
        // remove the temporary node will remove all deleted edges in bulk,
455
        // the cost is equal to the number of edges
456
        self.graph.remove_node(tmp_node);
22✔
457

22✔
458
        Ok(())
22✔
459
    }
34✔
460

461
    /// Whether cycle checking is enabled for the DiGraph/DAG.
462
    ///
463
    /// If set to ``True`` adding new edges that would introduce a cycle
464
    /// will raise a :class:`DAGWouldCycle` exception.
465
    #[getter]
466
    fn get_check_cycle(&self) -> bool {
8✔
467
        self.check_cycle
8✔
468
    }
8✔
469

470
    #[setter]
471
    fn set_check_cycle(&mut self, value: bool) -> PyResult<()> {
14✔
472
        if !self.check_cycle && value && !is_directed_acyclic_graph(self) {
14✔
473
            return Err(DAGHasCycle::new_err("PyDiGraph object has a cycle"));
2✔
474
        }
12✔
475
        self.check_cycle = value;
12✔
476
        Ok(())
12✔
477
    }
14✔
478

479
    /// Whether the graph is a multigraph (allows multiple edges between
480
    /// nodes) or not
481
    ///
482
    /// If set to ``False`` multiple edges between nodes are not allowed and
483
    /// calls that would add a parallel edge will instead update the existing
484
    /// edge
485
    #[getter]
486
    fn multigraph(&self) -> bool {
18✔
487
        self.multigraph
18✔
488
    }
18✔
489

490
    /// Detect if the graph has parallel edges or not
491
    ///
492
    /// :returns: ``True`` if the graph has parallel edges, otherwise ``False``
493
    /// :rtype: bool
494
    #[pyo3(text_signature = "(self)")]
495
    fn has_parallel_edges(&self) -> bool {
8✔
496
        if !self.multigraph {
8✔
497
            return false;
2✔
498
        }
6✔
499
        self.graph.has_parallel_edges()
6✔
500
    }
8✔
501

502
    /// Clear all nodes and edges
503
    #[pyo3(text_signature = "(self)")]
504
    pub fn clear(&mut self) {
4✔
505
        self.graph.clear();
4✔
506
        self.node_removed = true;
4✔
507
    }
4✔
508

509
    /// Clears all edges, leaves nodes intact
510
    #[pyo3(text_signature = "(self)")]
511
    pub fn clear_edges(&mut self) {
4✔
512
        self.graph.clear_edges();
4✔
513
    }
4✔
514

515
    /// Return the number of nodes in the graph
516
    #[pyo3(text_signature = "(self)")]
517
    pub fn num_nodes(&self) -> usize {
28✔
518
        self.graph.node_count()
28✔
519
    }
28✔
520

521
    /// Return the number of edges in the graph
522
    #[pyo3(text_signature = "(self)")]
523
    pub fn num_edges(&self) -> usize {
34✔
524
        self.graph.edge_count()
34✔
525
    }
34✔
526

527
    /// Return a list of all edge data.
528
    ///
529
    /// :returns: A list of all the edge data objects in the graph
530
    /// :rtype: list
531
    #[pyo3(text_signature = "(self)")]
532
    pub fn edges(&self) -> Vec<&PyObject> {
242✔
533
        self.graph
242✔
534
            .edge_indices()
242✔
535
            .map(|edge| self.graph.edge_weight(edge).unwrap())
8,614✔
536
            .collect()
242✔
537
    }
242✔
538

539
    /// Return a list of all edge indices.
540
    ///
541
    /// :returns: A list of all the edge indices in the graph
542
    /// :rtype: EdgeIndices
543
    #[pyo3(text_signature = "(self)")]
544
    pub fn edge_indices(&self) -> EdgeIndices {
36✔
545
        EdgeIndices {
36✔
546
            edges: self.graph.edge_indices().map(|edge| edge.index()).collect(),
70✔
547
        }
36✔
548
    }
36✔
549

550
    /// Return a list of indices of all directed edges between specified nodes
551
    ///
552
    /// :returns: A list of all the edge indices connecting the specified start and end node
553
    /// :rtype: EdgeIndices
554
    pub fn edge_indices_from_endpoints(&self, node_a: usize, node_b: usize) -> EdgeIndices {
6✔
555
        let node_a_index = NodeIndex::new(node_a);
6✔
556
        let node_b_index = NodeIndex::new(node_b);
6✔
557

6✔
558
        EdgeIndices {
6✔
559
            edges: self
6✔
560
                .graph
6✔
561
                .edges_directed(node_a_index, petgraph::Direction::Outgoing)
6✔
562
                .filter(|edge| edge.target() == node_b_index)
24✔
563
                .map(|edge| edge.id().index())
6✔
564
                .collect(),
6✔
565
        }
6✔
566
    }
6✔
567

568
    /// Return a list of all node data.
569
    ///
570
    /// :returns: A list of all the node data objects in the graph
571
    /// :rtype: list
572
    #[pyo3(text_signature = "(self)")]
573
    pub fn nodes(&self) -> Vec<&PyObject> {
156✔
574
        self.graph
156✔
575
            .node_indices()
156✔
576
            .map(|node| self.graph.node_weight(node).unwrap())
8,274✔
577
            .collect()
156✔
578
    }
156✔
579

580
    /// Return a list of all node indices.
581
    ///
582
    /// :returns: A list of all the node indices in the graph
583
    /// :rtype: NodeIndices
584
    #[pyo3(text_signature = "(self)")]
585
    pub fn node_indices(&self) -> NodeIndices {
208✔
586
        NodeIndices {
208✔
587
            nodes: self.graph.node_indices().map(|node| node.index()).collect(),
9,520✔
588
        }
208✔
589
    }
208✔
590

591
    /// Return a list of all node indices.
592
    ///
593
    /// .. note::
594
    ///
595
    ///     This is identical to :meth:`.node_indices()`, which is the
596
    ///     preferred method to get the node indices in the graph. This
597
    ///     exists for backwards compatibility with earlier releases.
598
    ///
599
    /// :returns: A list of all the node indices in the graph
600
    /// :rtype: NodeIndices
601
    #[pyo3(text_signature = "(self)")]
602
    pub fn node_indexes(&self) -> NodeIndices {
112✔
603
        self.node_indices()
112✔
604
    }
112✔
605

606
    /// Return True if there is a node in the graph.
607
    ///
608
    /// :param int node: The node index to check
609
    ///
610
    /// :returns: True if there is a node false if there is no node
611
    /// :rtype: bool
612
    #[pyo3(text_signature = "(self, node, /)")]
613
    pub fn has_node(&self, node: usize) -> bool {
4✔
614
        let index = NodeIndex::new(node);
4✔
615
        self.graph.contains_node(index)
4✔
616
    }
4✔
617

618
    /// Return True if there is an edge from node_a to node_b.
619
    ///
620
    /// :param int node_a: The source node index to check for an edge
621
    /// :param int node_b: The destination node index to check for an edge
622
    ///
623
    /// :returns: True if there is an edge false if there is no edge
624
    /// :rtype: bool
625
    #[pyo3(text_signature = "(self, node_a, node_b, /)")]
626
    pub fn has_edge(&self, node_a: usize, node_b: usize) -> bool {
170✔
627
        let index_a = NodeIndex::new(node_a);
170✔
628
        let index_b = NodeIndex::new(node_b);
170✔
629
        self.graph.find_edge(index_a, index_b).is_some()
170✔
630
    }
170✔
631

632
    /// Return a list of all the node successor data.
633
    ///
634
    /// :param int node: The index for the node to get the successors for
635
    ///
636
    /// :returns: A list of the node data for all the child neighbor nodes
637
    /// :rtype: list
638
    #[pyo3(text_signature = "(self, node, /)")]
639
    pub fn successors(&self, node: usize) -> Vec<&PyObject> {
6✔
640
        let index = NodeIndex::new(node);
6✔
641
        let children = self
6✔
642
            .graph
6✔
643
            .neighbors_directed(index, petgraph::Direction::Outgoing);
6✔
644
        let mut successors: Vec<&PyObject> = Vec::new();
6✔
645
        let mut used_indices: HashSet<NodeIndex> = HashSet::new();
6✔
646
        for succ in children {
32✔
647
            if !used_indices.contains(&succ) {
26✔
648
                successors.push(self.graph.node_weight(succ).unwrap());
24✔
649
                used_indices.insert(succ);
24✔
650
            }
24✔
651
        }
652
        successors
6✔
653
    }
6✔
654

655
    /// Return a list of all the node predecessor data.
656
    ///
657
    /// :param int node: The index for the node to get the predecessors for
658
    ///
659
    /// :returns: A list of the node data for all the parent neighbor nodes
660
    /// :rtype: list
661
    #[pyo3(text_signature = "(self, node, /)")]
662
    pub fn predecessors(&self, node: usize) -> Vec<&PyObject> {
6✔
663
        let index = NodeIndex::new(node);
6✔
664
        let parents = self
6✔
665
            .graph
6✔
666
            .neighbors_directed(index, petgraph::Direction::Incoming);
6✔
667
        let mut predec: Vec<&PyObject> = Vec::new();
6✔
668
        let mut used_indices: HashSet<NodeIndex> = HashSet::new();
6✔
669
        for pred in parents {
32✔
670
            if !used_indices.contains(&pred) {
26✔
671
                predec.push(self.graph.node_weight(pred).unwrap());
24✔
672
                used_indices.insert(pred);
24✔
673
            }
24✔
674
        }
675
        predec
6✔
676
    }
6✔
677

678
    /// Return a filtered list of successors data such that each
679
    /// node has at least one edge data which matches the filter.
680
    ///
681
    /// :param int node: The index for the node to get the successors for
682
    ///
683
    /// :param filter_fn: The filter function to use for matching nodes. It takes
684
    ///     in one argument, the edge data payload/weight object, and will return a
685
    ///     boolean whether the edge matches the conditions or not. If any edge returns
686
    ///     ``True``, the node will be included.
687
    ///
688
    /// :returns: A list of the node data for all the child neighbor nodes
689
    ///           whose at least one edge matches the filter
690
    /// :rtype: list
691
    #[pyo3(text_signature = "(self, node, filter_fn, /)")]
692
    pub fn find_successors_by_edge(
14✔
693
        &self,
14✔
694
        py: Python,
14✔
695
        node: usize,
14✔
696
        filter_fn: PyObject,
14✔
697
    ) -> PyResult<Vec<&PyObject>> {
14✔
698
        let index = NodeIndex::new(node);
14✔
699
        let mut successors: Vec<&PyObject> = Vec::new();
14✔
700
        let mut used_indices: HashSet<NodeIndex> = HashSet::new();
14✔
701

14✔
702
        let filter_edge = |edge: &PyObject| -> PyResult<bool> {
50✔
703
            let res = filter_fn.call1(py, (edge,))?;
50✔
704
            res.extract(py)
50✔
705
        };
50✔
706

707
        let raw_edges = self
14✔
708
            .graph
14✔
709
            .edges_directed(index, petgraph::Direction::Outgoing);
14✔
710

711
        for edge in raw_edges {
66✔
712
            let succ = edge.target();
52✔
713
            if !used_indices.contains(&succ) {
52✔
714
                let edge_weight = edge.weight();
50✔
715
                if filter_edge(edge_weight)? {
50✔
716
                    used_indices.insert(succ);
26✔
717
                    successors.push(self.graph.node_weight(succ).unwrap());
26✔
718
                }
26✔
719
            }
2✔
720
        }
721
        Ok(successors)
14✔
722
    }
14✔
723

724
    /// Return a filtered list of predecessor data such that each
725
    /// node has at least one edge data which matches the filter.
726
    ///
727
    /// :param int node: The index for the node to get the predecessor for
728
    ///
729
    /// :param filter_fn: The filter function to use for matching nodes. It takes
730
    ///     in one argument, the edge data payload/weight object, and will return a
731
    ///     boolean whether the edge matches the conditions or not. If any edge returns
732
    ///     ``True``, the node will be included.
733
    ///
734
    /// :returns: A list of the node data for all the parent neighbor nodes
735
    ///           whose at least one edge matches the filter
736
    /// :rtype: list
737
    #[pyo3(text_signature = "(self, node, filter_fn, /)")]
738
    pub fn find_predecessors_by_edge(
14✔
739
        &self,
14✔
740
        py: Python,
14✔
741
        node: usize,
14✔
742
        filter_fn: PyObject,
14✔
743
    ) -> PyResult<Vec<&PyObject>> {
14✔
744
        let index = NodeIndex::new(node);
14✔
745
        let mut predec: Vec<&PyObject> = Vec::new();
14✔
746
        let mut used_indices: HashSet<NodeIndex> = HashSet::new();
14✔
747

14✔
748
        let filter_edge = |edge: &PyObject| -> PyResult<bool> {
52✔
749
            let res = filter_fn.call1(py, (edge,))?;
52✔
750
            res.extract(py)
52✔
751
        };
52✔
752

753
        let raw_edges = self
14✔
754
            .graph
14✔
755
            .edges_directed(index, petgraph::Direction::Incoming);
14✔
756

757
        for edge in raw_edges {
66✔
758
            let pred = edge.source();
52✔
759
            if !used_indices.contains(&pred) {
52✔
760
                let edge_weight = edge.weight();
52✔
761
                if filter_edge(edge_weight)? {
52✔
762
                    used_indices.insert(pred);
26✔
763
                    predec.push(self.graph.node_weight(pred).unwrap());
26✔
764
                }
26✔
765
            }
×
766
        }
767
        Ok(predec)
14✔
768
    }
14✔
769

770
    /// Return the edge data for an edge between 2 nodes.
771
    ///
772
    /// :param int node_a: The index for the first node
773
    /// :param int node_b: The index for the second node
774
    ///
775
    /// :returns: The data object set for the edge
776
    /// :raises NoEdgeBetweenNodes: When there is no edge between nodes
777
    #[pyo3(text_signature = "(self, node_a, node_b, /)")]
778
    pub fn get_edge_data(&self, node_a: usize, node_b: usize) -> PyResult<&PyObject> {
76✔
779
        let index_a = NodeIndex::new(node_a);
76✔
780
        let index_b = NodeIndex::new(node_b);
76✔
781
        let edge_index = match self.graph.find_edge(index_a, index_b) {
76✔
782
            Some(edge_index) => edge_index,
72✔
783
            None => return Err(NoEdgeBetweenNodes::new_err("No edge found between nodes")),
4✔
784
        };
785

786
        let data = self.graph.edge_weight(edge_index).unwrap();
72✔
787
        Ok(data)
72✔
788
    }
76✔
789

790
    /// Return the edge data for the edge by its given index
791
    ///
792
    /// :param int edge_index: The edge index to get the data for
793
    ///
794
    /// :returns: The data object for the edge
795
    /// :raises IndexError: when there is no edge present with the provided
796
    ///     index
797
    #[pyo3(text_signature = "(self, edge_index, /)")]
798
    pub fn get_edge_data_by_index(&self, edge_index: usize) -> PyResult<&PyObject> {
4✔
799
        let data = match self.graph.edge_weight(EdgeIndex::new(edge_index)) {
4✔
800
            Some(data) => data,
2✔
801
            None => {
802
                return Err(PyIndexError::new_err(format!(
2✔
803
                    "Provided edge index {} is not present in the graph",
2✔
804
                    edge_index
2✔
805
                )));
2✔
806
            }
807
        };
808
        Ok(data)
2✔
809
    }
4✔
810

811
    /// Return the edge endpoints for the edge by its given index
812
    ///
813
    /// :param int edge_index: The edge index to get the endpoints for
814
    ///
815
    /// :returns: The endpoint tuple for the edge
816
    /// :rtype: tuple
817
    /// :raises IndexError: when there is no edge present with the provided
818
    ///     index
819
    #[pyo3(text_signature = "(self, edge_index, /)")]
820
    pub fn get_edge_endpoints_by_index(&self, edge_index: usize) -> PyResult<(usize, usize)> {
4✔
821
        let endpoints = match self.graph.edge_endpoints(EdgeIndex::new(edge_index)) {
4✔
822
            Some(endpoints) => (endpoints.0.index(), endpoints.1.index()),
2✔
823
            None => {
824
                return Err(PyIndexError::new_err(format!(
2✔
825
                    "Provided edge index {} is not present in the graph",
2✔
826
                    edge_index
2✔
827
                )));
2✔
828
            }
829
        };
830
        Ok(endpoints)
2✔
831
    }
4✔
832

833
    /// Update an edge's weight/payload inplace
834
    ///
835
    /// If there are parallel edges in the graph only one edge will be updated.
836
    /// if you need to update a specific edge or need to ensure all parallel
837
    /// edges get updated you should use
838
    /// :meth:`~rustworkx.PyDiGraph.update_edge_by_index` instead.
839
    ///
840
    /// :param int source: The index for the first node
841
    /// :param int target: The index for the second node
842
    ///
843
    /// :raises NoEdgeBetweenNodes: When there is no edge between nodes
844
    #[pyo3(text_signature = "(self, source, target, edge, /)")]
845
    pub fn update_edge(&mut self, source: usize, target: usize, edge: PyObject) -> PyResult<()> {
4✔
846
        let index_a = NodeIndex::new(source);
4✔
847
        let index_b = NodeIndex::new(target);
4✔
848
        let edge_index = match self.graph.find_edge(index_a, index_b) {
4✔
849
            Some(edge_index) => edge_index,
2✔
850
            None => return Err(NoEdgeBetweenNodes::new_err("No edge found between nodes")),
2✔
851
        };
852
        let data = self.graph.edge_weight_mut(edge_index).unwrap();
2✔
853
        *data = edge;
2✔
854
        Ok(())
2✔
855
    }
4✔
856

857
    /// Update an edge's weight/payload by the edge index
858
    ///
859
    /// :param int edge_index: The index for the edge
860
    /// :param object edge: The data payload/weight to update the edge with
861
    ///
862
    /// :raises IndexError: when there is no edge present with the provided
863
    ///     index
864
    #[pyo3(text_signature = "(self, edge_index, edge, /)")]
865
    pub fn update_edge_by_index(&mut self, edge_index: usize, edge: PyObject) -> PyResult<()> {
4,190✔
866
        match self.graph.edge_weight_mut(EdgeIndex::new(edge_index)) {
4,190✔
867
            Some(data) => *data = edge,
4,188✔
868
            None => return Err(PyIndexError::new_err("No edge found for index")),
2✔
869
        };
870
        Ok(())
4,188✔
871
    }
4,190✔
872

873
    /// Return the node data for a given node index
874
    ///
875
    /// :param int node: The index for the node
876
    ///
877
    /// :returns: The data object set for that node
878
    /// :raises IndexError: when an invalid node index is provided
879
    #[pyo3(text_signature = "(self, node, /)")]
880
    pub fn get_node_data(&self, node: usize) -> PyResult<&PyObject> {
96✔
881
        let index = NodeIndex::new(node);
96✔
882
        let node = match self.graph.node_weight(index) {
96✔
883
            Some(node) => node,
94✔
884
            None => return Err(PyIndexError::new_err("No node found for index")),
2✔
885
        };
886
        Ok(node)
94✔
887
    }
96✔
888

889
    /// Return the edge data for all the edges between 2 nodes.
890
    ///
891
    /// :param int node_a: The index for the first node
892
    /// :param int node_b: The index for the second node
893
    ///
894
    /// :returns: A list with all the data objects for the edges between nodes
895
    /// :rtype: list
896
    /// :raises NoEdgeBetweenNodes: When there is no edge between nodes
897
    #[pyo3(text_signature = "(self, node_a, node_b, /)")]
898
    pub fn get_all_edge_data(&self, node_a: usize, node_b: usize) -> PyResult<Vec<&PyObject>> {
22✔
899
        let index_a = NodeIndex::new(node_a);
22✔
900
        let index_b = NodeIndex::new(node_b);
22✔
901
        let raw_edges = self
22✔
902
            .graph
22✔
903
            .edges_directed(index_a, petgraph::Direction::Outgoing);
22✔
904
        let out: Vec<&PyObject> = raw_edges
22✔
905
            .filter(|x| x.target() == index_b)
28✔
906
            .map(|edge| edge.weight())
24✔
907
            .collect();
22✔
908
        if out.is_empty() {
22✔
909
            Err(NoEdgeBetweenNodes::new_err("No edge found between nodes"))
2✔
910
        } else {
911
            Ok(out)
20✔
912
        }
913
    }
22✔
914

915
    /// Get edge list
916
    ///
917
    /// Returns a list of tuples of the form ``(source, target)`` where
918
    /// ``source`` and ``target`` are the node indices.
919
    ///
920
    /// :returns: An edge list without weights
921
    /// :rtype: EdgeList
922
    pub fn edge_list(&self) -> EdgeList {
208✔
923
        EdgeList {
208✔
924
            edges: self
208✔
925
                .graph
208✔
926
                .edge_references()
208✔
927
                .map(|edge| (edge.source().index(), edge.target().index()))
2,002,866✔
928
                .collect(),
208✔
929
        }
208✔
930
    }
208✔
931

932
    /// Get edge list with weights
933
    ///
934
    /// Returns a list of tuples of the form ``(source, target, weight)`` where
935
    /// ``source`` and ``target`` are the node indices and ``weight`` is the
936
    /// payload of the edge.
937
    ///
938
    /// :returns: An edge list with weights
939
    /// :rtype: WeightedEdgeList
940
    pub fn weighted_edge_list(&self, py: Python) -> WeightedEdgeList {
178✔
941
        WeightedEdgeList {
178✔
942
            edges: self
178✔
943
                .graph
178✔
944
                .edge_references()
178✔
945
                .map(|edge| {
29,260✔
946
                    (
29,260✔
947
                        edge.source().index(),
29,260✔
948
                        edge.target().index(),
29,260✔
949
                        edge.weight().clone_ref(py),
29,260✔
950
                    )
29,260✔
951
                })
29,260✔
952
                .collect(),
178✔
953
        }
178✔
954
    }
178✔
955

956
    /// Get an edge index map
957
    ///
958
    /// Returns a read only mapping from edge indices to the weighted edge
959
    /// tuple. The return is a mapping of the form:
960
    /// ``{0: (0, 1, "weight"), 1: (2, 3, 2.3)}``
961
    ///
962
    /// :returns: An edge index map
963
    /// :rtype: EdgeIndexMap
964
    #[pyo3(text_signature = "(self)")]
965
    pub fn edge_index_map(&self, py: Python) -> EdgeIndexMap {
64✔
966
        EdgeIndexMap {
64✔
967
            edge_map: self
64✔
968
                .graph
64✔
969
                .edge_references()
64✔
970
                .map(|edge| {
4,246✔
971
                    (
4,246✔
972
                        edge.id().index(),
4,246✔
973
                        (
4,246✔
974
                            edge.source().index(),
4,246✔
975
                            edge.target().index(),
4,246✔
976
                            edge.weight().clone_ref(py),
4,246✔
977
                        ),
4,246✔
978
                    )
4,246✔
979
                })
4,246✔
980
                .collect(),
64✔
981
        }
64✔
982
    }
64✔
983

984
    /// Remove a node from the graph.
985
    ///
986
    /// :param int node: The index of the node to remove. If the index is not
987
    ///     present in the graph it will be ignored and this function will have
988
    ///     no effect.
989
    #[pyo3(text_signature = "(self, node, /)")]
990
    pub fn remove_node(&mut self, node: usize) -> PyResult<()> {
196✔
991
        let index = NodeIndex::new(node);
196✔
992
        self.graph.remove_node(index);
196✔
993
        self.node_removed = true;
196✔
994
        Ok(())
196✔
995
    }
196✔
996

997
    /// Remove a node from the graph and add edges from all predecessors to all
998
    /// successors
999
    ///
1000
    /// By default the data/weight on edges into the removed node will be used
1001
    /// for the retained edges.
1002
    ///
1003
    /// This function has a minimum time complexity of :math:`\mathcal O(e_i e_o)`, where
1004
    /// :math:`e_i` and :math:`e_o` are the numbers of incoming and outgoing edges respectively.
1005
    /// If your ``condition`` can be cast as an equality between two hashable quantities, consider
1006
    /// using :meth:`remove_node_retain_edges_by_key` instead, or if your ``condition`` is
1007
    /// referential object identity of the edge weights, consider
1008
    /// :meth:`remove_node_retain_edges_by_id`.
1009
    ///
1010
    /// :param int node: The index of the node to remove. If the index is not
1011
    ///     present in the graph it will be ignored and this function will have
1012
    ///     no effect.
1013
    /// :param bool use_outgoing: If set to true the weight/data from the
1014
    ///     edge outgoing from ``node`` will be used in the retained edge
1015
    ///     instead of the default weight/data from the incoming edge.
1016
    /// :param condition: A callable that will be passed 2 edge weight/data
1017
    ///     objects, one from the incoming edge to ``node`` the other for the
1018
    ///     outgoing edge, and will return a ``bool`` on whether an edge should
1019
    ///     be retained. For example setting this kwarg to::
1020
    ///
1021
    ///         lambda in_edge, out_edge: in_edge == out_edge
1022
    ///
1023
    ///     would only retain edges if the input edge to ``node`` had the same
1024
    ///     data payload as the outgoing edge.
1025
    #[pyo3(text_signature = "(self, node, /, use_outgoing=False, condition=None)")]
1026
    #[pyo3(signature=(node, use_outgoing=false, condition=None))]
1027
    pub fn remove_node_retain_edges(
14✔
1028
        &mut self,
14✔
1029
        py: Python,
14✔
1030
        node: usize,
14✔
1031
        use_outgoing: bool,
14✔
1032
        condition: Option<PyObject>,
14✔
1033
    ) -> PyResult<()> {
14✔
1034
        let index = NodeIndex::new(node);
14✔
1035
        let mut edge_list: Vec<(NodeIndex, NodeIndex, PyObject)> = Vec::new();
14✔
1036

1037
        fn check_condition(
28✔
1038
            py: Python,
28✔
1039
            condition: &Option<PyObject>,
28✔
1040
            in_weight: &PyObject,
28✔
1041
            out_weight: &PyObject,
28✔
1042
        ) -> PyResult<bool> {
28✔
1043
            match condition {
28✔
1044
                Some(condition) => {
8✔
1045
                    let res = condition.call1(py, (in_weight, out_weight))?;
8✔
1046
                    Ok(res.extract(py)?)
8✔
1047
                }
1048
                None => Ok(true),
20✔
1049
            }
1050
        }
28✔
1051

1052
        for (source, in_weight) in self
18✔
1053
            .graph
14✔
1054
            .edges_directed(index, petgraph::Direction::Incoming)
14✔
1055
            .map(|x| (x.source(), x.weight()))
18✔
1056
        {
1057
            for (target, out_weight) in self
28✔
1058
                .graph
18✔
1059
                .edges_directed(index, petgraph::Direction::Outgoing)
18✔
1060
                .map(|x| (x.target(), x.weight()))
28✔
1061
            {
1062
                let weight = if use_outgoing { out_weight } else { in_weight };
28✔
1063
                if check_condition(py, &condition, in_weight, out_weight)? {
28✔
1064
                    edge_list.push((source, target, weight.clone_ref(py)));
24✔
1065
                }
24✔
1066
            }
1067
        }
1068
        for (source, target, weight) in edge_list {
38✔
1069
            self._add_edge(source, target, weight)?;
24✔
1070
        }
1071
        self.node_removed = self.graph.remove_node(index).is_some();
14✔
1072
        Ok(())
14✔
1073
    }
14✔
1074

1075
    /// Remove a node from the graph and add edges from predecessors to successors in cases where
1076
    /// an incoming and outgoing edge have the same weight by Python object identity.
1077
    ///
1078
    /// This function has a minimum time complexity of :math:`\mathcal O(e_i + e_o)`, where
1079
    /// :math:`e_i` is the number of incoming edges and :math:`e_o` the number of outgoing edges
1080
    /// (the full complexity depends on the number of new edges to be created).
1081
    ///
1082
    /// Edges will be added between all pairs of predecessor and successor nodes that have the same
1083
    /// weight.  As a consequence, any weight which appears only on predecessor edges will not
1084
    /// appear in the output, as there are no successors to pair it with.
1085
    ///
1086
    /// :param int node: The index of the node to remove. If the index is not present in the graph
1087
    ///     it will be ignored and this function will have no effect.
1088
    #[pyo3(signature=(node, /))]
1089
    pub fn remove_node_retain_edges_by_id(&mut self, py: Python, node: usize) -> PyResult<()> {
8✔
1090
        // As many indices as will fit inline within the minimum inline size of a `SmallVec`.  Many
1091
        // use cases of this likely only have one inbound and outbound edge with each id anyway.
1092
        const INLINE_SIZE: usize =
1093
            2 * ::std::mem::size_of::<usize>() / ::std::mem::size_of::<NodeIndex>();
1094
        let new_node_list = || SmallVec::<[NodeIndex; INLINE_SIZE]>::new();
44✔
1095
        let node_index = NodeIndex::new(node);
8✔
1096
        let in_edges = {
8✔
1097
            let mut in_edges = HashMap::new();
8✔
1098
            for edge in self
26✔
1099
                .graph
8✔
1100
                .edges_directed(node_index, petgraph::Direction::Incoming)
8✔
1101
            {
26✔
1102
                in_edges
26✔
1103
                    .entry(PyAnyId(edge.weight().clone_ref(py)))
26✔
1104
                    .or_insert_with(new_node_list)
26✔
1105
                    .push(edge.source());
26✔
1106
            }
26✔
1107
            in_edges
8✔
1108
        };
1109
        let mut out_edges = {
8✔
1110
            let mut out_edges = HashMap::new();
8✔
1111
            for edge in self
26✔
1112
                .graph
8✔
1113
                .edges_directed(node_index, petgraph::Direction::Outgoing)
8✔
1114
            {
26✔
1115
                out_edges
26✔
1116
                    .entry(PyAnyId(edge.weight().clone_ref(py)))
26✔
1117
                    .or_insert_with(new_node_list)
26✔
1118
                    .push(edge.target());
26✔
1119
            }
26✔
1120
            out_edges
8✔
1121
        };
1122

1123
        for (weight, in_edges_subset) in in_edges {
30✔
1124
            let out_edges_subset = match out_edges.remove(&weight) {
22✔
1125
                Some(out_edges_key) => out_edges_key,
20✔
1126
                None => continue,
2✔
1127
            };
1128
            for source in in_edges_subset {
44✔
1129
                for target in out_edges_subset.iter() {
30✔
1130
                    self._add_edge(source, *target, weight.clone_ref(py))?;
30✔
1131
                }
1132
            }
1133
        }
1134
        self.node_removed = self.graph.remove_node(node_index).is_some();
8✔
1135
        Ok(())
8✔
1136
    }
8✔
1137

1138
    /// Remove a node from the graph and add edges from predecessors to successors in cases where
1139
    /// an incoming and outgoing edge have the same weight by Python object equality.
1140
    ///
1141
    /// This function has a minimum time complexity of :math:`\mathcal O(e_i + e_o)`, where
1142
    /// :math:`e_i` is the number of incoming edges and :math:`e_o` the number of outgoing edges
1143
    /// (the full complexity depends on the number of new edges to be created).
1144
    ///
1145
    /// Edges will be added between all pairs of predecessor and successor nodes that have equal
1146
    /// weights.  As a consequence, any weight which appears only on predecessor edges will not
1147
    /// appear in the output, as there are no successors to pair it with.
1148
    ///
1149
    /// If there are multiple edges with the same weight, the exact Python object used on the new
1150
    /// edges is an implementation detail and may change.  The only guarantees are that it will be
1151
    /// deterministic for a given graph, and that it will be drawn from the incoming edges if
1152
    /// ``use_outgoing=False`` (the default) or from the outgoing edges if ``use_outgoing=True``.
1153
    ///
1154
    /// :param int node: The index of the node to remove. If the index is not present in the graph
1155
    ///     it will be ignored and this function will have no effect.
1156
    /// :param key: A callable Python object that is called once for each connected edge, to
1157
    ///     generate the "key" for that weight.  It is passed exactly one argument positionally
1158
    ///     (the weight of the edge), and should return a Python object that is hashable and
1159
    ///     implements equality checking with all other relevant keys.  If not given, the edge
1160
    ///     weights will be used directly.
1161
    /// :param bool use_outgoing: If ``False`` (default), the new edges will use the weight from
1162
    ///     one of the incoming edges.  If ``True``, they will instead use a weight from one of the
1163
    ///     outgoing edges.
1164
    #[pyo3(signature=(node, /, key=None, *, use_outgoing=false))]
1165
    pub fn remove_node_retain_edges_by_key(
4✔
1166
        &mut self,
4✔
1167
        py: Python,
4✔
1168
        node: usize,
4✔
1169
        key: Option<Py<PyAny>>,
4✔
1170
        use_outgoing: bool,
4✔
1171
    ) -> PyResult<()> {
4✔
1172
        let node_index = NodeIndex::new(node);
4✔
1173
        let in_edges = {
4✔
1174
            let in_edges = PyDict::new(py);
4✔
1175
            for edge in self
14✔
1176
                .graph
4✔
1177
                .edges_directed(node_index, petgraph::Direction::Incoming)
4✔
1178
            {
1179
                let key_value = if let Some(key_fn) = &key {
14✔
1180
                    key_fn.call1(py, (edge.weight(),))?
14✔
1181
                } else {
1182
                    edge.weight().clone_ref(py)
×
1183
                };
1184
                if let Some(edge_data) = in_edges.get_item(key_value.bind(py))? {
14✔
1185
                    let edge_data = edge_data.downcast::<RemoveNodeEdgeValue>()?;
4✔
1186
                    edge_data.borrow_mut().nodes.push(edge.source());
4✔
1187
                } else {
1188
                    in_edges.set_item(
10✔
1189
                        key_value,
10✔
1190
                        RemoveNodeEdgeValue {
10✔
1191
                            weight: edge.weight().clone_ref(py),
10✔
1192
                            nodes: vec![edge.source()],
10✔
1193
                        }
10✔
1194
                        .into_pyobject(py)?,
10✔
UNCOV
1195
                    )?
×
1196
                }
1197
            }
1198
            in_edges
4✔
1199
        };
1200
        let out_edges = {
4✔
1201
            let out_edges = PyDict::new(py);
4✔
1202
            for edge in self
14✔
1203
                .graph
4✔
1204
                .edges_directed(node_index, petgraph::Direction::Outgoing)
4✔
1205
            {
1206
                let key_value = if let Some(key_fn) = &key {
14✔
1207
                    key_fn.call1(py, (edge.weight(),))?
14✔
1208
                } else {
1209
                    edge.weight().clone_ref(py)
×
1210
                };
1211
                if let Some(edge_data) = out_edges.get_item(key_value.bind(py))? {
14✔
1212
                    let edge_data = edge_data.downcast::<RemoveNodeEdgeValue>()?;
4✔
1213
                    edge_data.borrow_mut().nodes.push(edge.target());
4✔
1214
                } else {
1215
                    out_edges.set_item(
10✔
1216
                        key_value,
10✔
1217
                        RemoveNodeEdgeValue {
10✔
1218
                            weight: edge.weight().clone_ref(py),
10✔
1219
                            nodes: vec![edge.target()],
10✔
1220
                        }
10✔
1221
                        .into_pyobject(py)?,
10✔
UNCOV
1222
                    )?
×
1223
                }
1224
            }
1225
            out_edges
4✔
1226
        };
1227

1228
        for (in_key, in_edge_data) in in_edges {
14✔
1229
            let in_edge_data = in_edge_data.downcast::<RemoveNodeEdgeValue>()?.borrow();
10✔
1230
            let out_edge_data = match out_edges.get_item(in_key)? {
10✔
1231
                Some(out_edge_data) => out_edge_data.downcast::<RemoveNodeEdgeValue>()?.borrow(),
8✔
1232
                None => continue,
2✔
1233
            };
1234
            for source in in_edge_data.nodes.iter() {
12✔
1235
                for target in out_edge_data.nodes.iter() {
18✔
1236
                    let weight = if use_outgoing {
18✔
1237
                        out_edge_data.weight.clone_ref(py)
2✔
1238
                    } else {
1239
                        in_edge_data.weight.clone_ref(py)
16✔
1240
                    };
1241
                    self._add_edge(*source, *target, weight)?;
18✔
1242
                }
1243
            }
1244
        }
1245
        self.node_removed = self.graph.remove_node(node_index).is_some();
4✔
1246
        Ok(())
4✔
1247
    }
4✔
1248

1249
    /// Add an edge between 2 nodes.
1250
    ///
1251
    /// Use add_child() or add_parent() to create a node with an edge at the
1252
    /// same time as an edge for better performance. Using this method
1253
    /// allows for adding duplicate edges between nodes if the ``multigraph``
1254
    /// attribute is set to ``True``.
1255
    ///
1256
    /// :param int parent: Index of the parent node
1257
    /// :param int child: Index of the child node
1258
    /// :param edge: The object to set as the data for the edge. It can be any
1259
    ///     python object.
1260
    ///
1261
    /// :returns: The edge index of the created edge
1262
    /// :rtype: int
1263
    ///
1264
    /// :raises: When the new edge will create a cycle
1265
    #[pyo3(text_signature = "(self, parent, child, edge, /)")]
1266
    pub fn add_edge(&mut self, parent: usize, child: usize, edge: PyObject) -> PyResult<usize> {
2,024,072✔
1267
        let p_index = NodeIndex::new(parent);
2,024,072✔
1268
        let c_index = NodeIndex::new(child);
2,024,072✔
1269
        if !self.graph.contains_node(p_index) || !self.graph.contains_node(c_index) {
2,024,072✔
1270
            return Err(PyIndexError::new_err(
6✔
1271
                "One of the endpoints of the edge does not exist in graph",
6✔
1272
            ));
6✔
1273
        }
2,024,066✔
1274
        let out_index = self._add_edge(p_index, c_index, edge)?;
2,024,066✔
1275
        Ok(out_index)
2,024,058✔
1276
    }
2,024,072✔
1277

1278
    /// Add new edges to the dag.
1279
    ///
1280
    /// :param iterable obj_list: An iterable of tuples of the form
1281
    ///     ``(parent, child, obj)`` to attach to the graph. ``parent`` and
1282
    ///     ``child`` are integer indices describing where an edge should be
1283
    ///     added, and obj is the python object for the edge data.
1284
    ///
1285
    /// :returns: A list of int indices of the newly created edges
1286
    /// :rtype: list
1287
    #[pyo3(text_signature = "(self, obj_list, /)")]
1288
    pub fn add_edges_from(&mut self, obj_list: Bound<'_, PyAny>) -> PyResult<Vec<usize>> {
322✔
1289
        let mut out_list = Vec::new();
322✔
1290
        for py_obj in obj_list.try_iter()? {
2,021,778✔
1291
            let obj = py_obj?.extract::<(usize, usize, PyObject)>()?;
2,021,778✔
1292
            let edge = self.add_edge(obj.0, obj.1, obj.2)?;
2,021,778✔
1293
            out_list.push(edge);
2,021,774✔
1294
        }
1295
        Ok(out_list)
318✔
1296
    }
322✔
1297

1298
    /// Add new edges to the dag without python data.
1299
    ///
1300
    /// :param iterable obj_list: An iterable of tuples of the form
1301
    ///     ``(parent, child)`` to attach to the graph. ``parent`` and
1302
    ///     ``child`` are integer indices describing where an edge should be
1303
    ///     added. Unlike :meth:`add_edges_from` there is no data payload and
1304
    ///     when the edge is created None will be used.
1305
    ///
1306
    /// :returns: A list of int indices of the newly created edges
1307
    /// :rtype: list
1308
    #[pyo3(text_signature = "(self, obj_list, /)")]
1309
    pub fn add_edges_from_no_data(
184✔
1310
        &mut self,
184✔
1311
        py: Python,
184✔
1312
        obj_list: Bound<'_, PyAny>,
184✔
1313
    ) -> PyResult<Vec<usize>> {
184✔
1314
        let mut out_list = Vec::new();
184✔
1315
        for py_obj in obj_list.try_iter()? {
1,406✔
1316
            let obj = py_obj?.extract::<(usize, usize)>()?;
1,406✔
1317
            let edge = self.add_edge(obj.0, obj.1, py.None())?;
1,406✔
1318
            out_list.push(edge);
1,402✔
1319
        }
1320
        Ok(out_list)
180✔
1321
    }
184✔
1322

1323
    /// Extend graph from an edge list
1324
    ///
1325
    /// This method differs from :meth:`add_edges_from_no_data` in that it will
1326
    /// add nodes if a node index is not present in the edge list.
1327
    ///
1328
    /// :param iterable edge_list: An iterable of tuples of the form ``(source, target)``
1329
    ///     where source and target are integer node indices. If the node index
1330
    ///     is not present in the graph, nodes will be added (with a node
1331
    ///     weight of ``None``) to that index.
1332
    #[pyo3(text_signature = "(self, edge_list, /)")]
1333
    pub fn extend_from_edge_list(
154✔
1334
        &mut self,
154✔
1335
        py: Python,
154✔
1336
        edge_list: Bound<'_, PyAny>,
154✔
1337
    ) -> PyResult<()> {
154✔
1338
        for py_obj in edge_list.try_iter()? {
1,600✔
1339
            let (source, target) = py_obj?.extract::<(usize, usize)>()?;
1,600✔
1340
            let max_index = cmp::max(source, target);
1,600✔
1341
            while max_index >= self.node_count() {
2,698✔
1342
                self.graph.add_node(py.None());
1,098✔
1343
            }
1,098✔
1344
            self._add_edge(NodeIndex::new(source), NodeIndex::new(target), py.None())?;
1,600✔
1345
        }
1346
        Ok(())
152✔
1347
    }
154✔
1348

1349
    /// Extend graph from a weighted edge list
1350
    ///
1351
    /// This method differs from :meth:`add_edges_from` in that it will
1352
    /// add nodes if a node index is not present in the edge list.
1353
    ///
1354
    /// :param iterable edge_list: An iterable of tuples of the form
1355
    ///     ``(source, target, weight)`` where source and target are integer
1356
    ///     node indices. If the node index is not present in the graph
1357
    ///     nodes will be added (with a node weight of ``None``) to that index.
1358
    #[pyo3(text_signature = "(self, edge_list, /)")]
1359
    pub fn extend_from_weighted_edge_list(
62✔
1360
        &mut self,
62✔
1361
        py: Python,
62✔
1362
        edge_list: Bound<'_, PyAny>,
62✔
1363
    ) -> PyResult<()> {
62✔
1364
        for py_obj in edge_list.try_iter()? {
312✔
1365
            let (source, target, weight) = py_obj?.extract::<(usize, usize, PyObject)>()?;
312✔
1366
            let max_index = cmp::max(source, target);
312✔
1367
            while max_index >= self.node_count() {
536✔
1368
                self.graph.add_node(py.None());
224✔
1369
            }
224✔
1370
            self._add_edge(NodeIndex::new(source), NodeIndex::new(target), weight)?;
312✔
1371
        }
1372
        Ok(())
60✔
1373
    }
62✔
1374

1375
    /// Insert a node between a list of reference nodes and all their predecessors
1376
    ///
1377
    /// This essentially iterates over all edges into the reference node
1378
    /// specified in the ``ref_nodes`` parameter removes those edges and then
1379
    /// adds 2 edges, one from the predecessor of ``ref_node`` to ``node``
1380
    /// and the other from ``node`` to ``ref_node``. The edge payloads for
1381
    /// the newly created edges are copied by reference from the original
1382
    /// edge that gets removed.
1383
    ///
1384
    /// :param int node: The node index to insert between
1385
    /// :param int ref_node: The reference node index to insert ``node``
1386
    ///     between
1387
    #[pyo3(text_signature = "(self, node, ref_nodes, /)")]
1388
    pub fn insert_node_on_in_edges_multiple(
8✔
1389
        &mut self,
8✔
1390
        py: Python,
8✔
1391
        node: usize,
8✔
1392
        ref_nodes: Vec<usize>,
8✔
1393
    ) -> PyResult<()> {
8✔
1394
        for ref_node in ref_nodes {
22✔
1395
            self.insert_between(py, node, ref_node, false)?;
14✔
1396
        }
1397
        Ok(())
8✔
1398
    }
8✔
1399

1400
    /// Insert a node between a list of reference nodes and all their successors
1401
    ///
1402
    /// This essentially iterates over all edges out of the reference node
1403
    /// specified in the ``ref_node`` parameter removes those edges and then
1404
    /// adds 2 edges, one from ``ref_node`` to ``node`` and the other from
1405
    /// ``node`` to the successor of ``ref_node``. The edge payloads for the
1406
    /// newly created edges are copied by reference from the original edge that
1407
    /// gets removed.
1408
    ///
1409
    /// :param int node: The node index to insert between
1410
    /// :param int ref_nodes: The list of node indices to insert ``node``
1411
    ///     between
1412
    #[pyo3(text_signature = "(self, node, ref_nodes, /)")]
1413
    pub fn insert_node_on_out_edges_multiple(
8✔
1414
        &mut self,
8✔
1415
        py: Python,
8✔
1416
        node: usize,
8✔
1417
        ref_nodes: Vec<usize>,
8✔
1418
    ) -> PyResult<()> {
8✔
1419
        for ref_node in ref_nodes {
22✔
1420
            self.insert_between(py, node, ref_node, true)?;
14✔
1421
        }
1422
        Ok(())
8✔
1423
    }
8✔
1424

1425
    /// Insert a node between a reference node and all its predecessor nodes
1426
    ///
1427
    /// This essentially iterates over all edges into the reference node
1428
    /// specified in the ``ref_node`` parameter removes those edges and then
1429
    /// adds 2 edges, one from the predecessor of ``ref_node`` to ``node`` and
1430
    /// the other from ``node`` to ``ref_node``. The edge payloads for the
1431
    /// newly created edges are copied by reference from the original edge that
1432
    /// gets removed.
1433
    ///
1434
    /// :param int node: The node index to insert between
1435
    /// :param int ref_node: The reference node index to insert ``node``
1436
    ///     between
1437
    #[pyo3(text_signature = "(self, node, ref_node, /)")]
1438
    pub fn insert_node_on_in_edges(
8✔
1439
        &mut self,
8✔
1440
        py: Python,
8✔
1441
        node: usize,
8✔
1442
        ref_node: usize,
8✔
1443
    ) -> PyResult<()> {
8✔
1444
        self.insert_between(py, node, ref_node, false)?;
8✔
1445
        Ok(())
8✔
1446
    }
8✔
1447

1448
    /// Insert a node between a reference node and all its successor nodes
1449
    ///
1450
    /// This essentially iterates over all edges out of the reference node
1451
    /// specified in the ``ref_node`` parameter removes those edges and then
1452
    /// adds 2 edges, one from ``ref_node`` to ``node`` and the other from
1453
    /// ``node`` to the successor of ``ref_node``. The edge payloads for the
1454
    /// newly created edges are copied by reference from the original edge
1455
    /// that gets removed.
1456
    ///
1457
    /// :param int node: The node index to insert between
1458
    /// :param int ref_node: The reference node index to insert ``node``
1459
    ///     between
1460
    #[pyo3(text_signature = "(self, node, ref_node, /)")]
1461
    pub fn insert_node_on_out_edges(
8✔
1462
        &mut self,
8✔
1463
        py: Python,
8✔
1464
        node: usize,
8✔
1465
        ref_node: usize,
8✔
1466
    ) -> PyResult<()> {
8✔
1467
        self.insert_between(py, node, ref_node, true)?;
8✔
1468
        Ok(())
8✔
1469
    }
8✔
1470

1471
    /// Remove an edge between 2 nodes.
1472
    ///
1473
    /// Note if there are multiple edges between the specified nodes only one
1474
    /// will be removed.
1475
    ///
1476
    /// :param int parent: The index for the parent node.
1477
    /// :param int child: The index of the child node.
1478
    ///
1479
    /// :raises NoEdgeBetweenNodes: If there are no edges between the nodes
1480
    ///     specified
1481
    #[pyo3(text_signature = "(self, parent, child, /)")]
1482
    pub fn remove_edge(&mut self, parent: usize, child: usize) -> PyResult<()> {
8✔
1483
        let p_index = NodeIndex::new(parent);
8✔
1484
        let c_index = NodeIndex::new(child);
8✔
1485
        let edge_index = match self.graph.find_edge(p_index, c_index) {
8✔
1486
            Some(edge_index) => edge_index,
4✔
1487
            None => return Err(NoEdgeBetweenNodes::new_err("No edge found between nodes")),
4✔
1488
        };
1489
        self.graph.remove_edge(edge_index);
4✔
1490
        Ok(())
4✔
1491
    }
8✔
1492

1493
    /// Remove an edge identified by the provided index
1494
    ///
1495
    /// :param int edge: The index of the edge to remove
1496
    #[pyo3(text_signature = "(self, edge, /)")]
1497
    pub fn remove_edge_from_index(&mut self, edge: usize) -> PyResult<()> {
12✔
1498
        let edge_index = EdgeIndex::new(edge);
12✔
1499
        self.graph.remove_edge(edge_index);
12✔
1500
        Ok(())
12✔
1501
    }
12✔
1502

1503
    /// Remove edges from the graph.
1504
    ///
1505
    /// Note if there are multiple edges between the specified nodes only one
1506
    /// will be removed.
1507
    ///
1508
    /// :param iterable index_list: An iterable of node index pairs to remove from
1509
    ///     the graph
1510
    ///
1511
    /// :raises NoEdgeBetweenNodes: If there are no edges between a specified
1512
    ///     pair of nodes.
1513
    #[pyo3(text_signature = "(self, index_list, /)")]
1514
    pub fn remove_edges_from(&mut self, index_list: Bound<'_, PyAny>) -> PyResult<()> {
6✔
1515
        for py_obj in index_list.try_iter()? {
10✔
1516
            let (x, y) = py_obj?.extract::<(usize, usize)>()?;
10✔
1517
            let (p_index, c_index) = (NodeIndex::new(x), NodeIndex::new(y));
10✔
1518
            let edge_index = match self.graph.find_edge(p_index, c_index) {
10✔
1519
                Some(edge_index) => edge_index,
8✔
1520
                None => return Err(NoEdgeBetweenNodes::new_err("No edge found between nodes")),
2✔
1521
            };
1522
            self.graph.remove_edge(edge_index);
8✔
1523
        }
1524
        Ok(())
4✔
1525
    }
6✔
1526

1527
    /// Add a new node to the graph.
1528
    ///
1529
    /// :param obj: The python object to attach to the node
1530
    ///
1531
    /// :returns: The index of the newly created node
1532
    /// :rtype: int
1533
    #[pyo3(text_signature = "(self, obj, /)")]
1534
    pub fn add_node(&mut self, obj: PyObject) -> PyResult<usize> {
403,030✔
1535
        let index = self.graph.add_node(obj);
403,030✔
1536
        Ok(index.index())
403,030✔
1537
    }
403,030✔
1538

1539
    /// Find node within this graph given a specific weight
1540
    ///
1541
    /// This algorithm has a worst case of O(n) since it searches the node
1542
    /// indices in order. If there is more than one node in the graph with the
1543
    /// same weight only the first match (by node index) will be returned.
1544
    ///
1545
    /// :param obj: The weight to look for in the graph.
1546
    ///
1547
    /// :returns: the index of the first node in the graph that is equal to the
1548
    ///     weight. If no match is found ``None`` will be returned.
1549
    /// :rtype: int
1550
    #[pyo3(text_signature = "(self, obj, /)")]
1551
    pub fn find_node_by_weight(&self, py: Python, obj: PyObject) -> PyResult<Option<usize>> {
6✔
1552
        find_node_by_weight(py, &self.graph, &obj).map(|node| node.map(|x| x.index()))
6✔
1553
    }
6✔
1554

1555
    /// Merge two nodes in the graph.
1556
    ///
1557
    /// If the nodes have equal weight objects then all the edges into and out of `u` will be added
1558
    /// to `v` and `u` will be removed from the graph. If the nodes don't have equal weight
1559
    /// objects then no changes will be made and no error raised
1560
    ///
1561
    /// :param int u: The source node that is going to be merged
1562
    /// :param int v: The target node that is going to be the new node
1563
    #[pyo3(text_signature = "(self, u, v, /)")]
1564
    pub fn merge_nodes(&mut self, py: Python, u: usize, v: usize) -> PyResult<()> {
8✔
1565
        let source_node = NodeIndex::new(u);
8✔
1566
        let target_node = NodeIndex::new(v);
8✔
1567

1568
        let source_weight = match self.graph.node_weight(source_node) {
8✔
1569
            Some(weight) => weight,
6✔
1570
            None => return Err(PyIndexError::new_err("No node found for index")),
2✔
1571
        };
1572

1573
        let target_weight = match self.graph.node_weight(target_node) {
6✔
1574
            Some(weight) => weight,
4✔
1575
            None => return Err(PyIndexError::new_err("No node found for index")),
2✔
1576
        };
1577

1578
        let have_same_weights =
4✔
1579
            source_weight.bind(py).compare(target_weight.bind(py))? == Ordering::Equal;
4✔
1580

4✔
1581
        if have_same_weights {
4✔
1582
            const DIRECTIONS: [petgraph::Direction; 2] =
1583
                [petgraph::Direction::Outgoing, petgraph::Direction::Incoming];
1584

1585
            let mut edges_to_add: Vec<(usize, usize, PyObject)> = Vec::new();
2✔
1586
            for dir in &DIRECTIONS {
6✔
1587
                for edge in self.graph.edges_directed(NodeIndex::new(u), *dir) {
4✔
1588
                    let s = edge.source();
4✔
1589
                    let d = edge.target();
4✔
1590

4✔
1591
                    if s.index() == u {
4✔
1592
                        edges_to_add.push((v, d.index(), edge.weight().clone_ref(py)));
2✔
1593
                    } else {
2✔
1594
                        edges_to_add.push((s.index(), v, edge.weight().clone_ref(py)));
2✔
1595
                    }
2✔
1596
                }
1597
            }
1598
            self.remove_node(u)?;
2✔
1599
            for edge in edges_to_add {
6✔
1600
                self.add_edge(edge.0, edge.1, edge.2)?;
4✔
1601
            }
1602
        }
2✔
1603

1604
        Ok(())
4✔
1605
    }
8✔
1606

1607
    /// Add a new child node to the graph.
1608
    ///
1609
    /// This will create a new node on the graph and add an edge from the parent
1610
    /// to that new node.
1611
    ///
1612
    /// :param int parent: The index for the parent node
1613
    /// :param obj: The python object to attach to the node
1614
    /// :param edge: The python object to attach to the edge
1615
    ///
1616
    /// :returns: The index of the newly created child node
1617
    /// :rtype: int
1618
    #[pyo3(text_signature = "(self, parent, obj, edge, /)")]
1619
    pub fn add_child(&mut self, parent: usize, obj: PyObject, edge: PyObject) -> PyResult<usize> {
401,382✔
1620
        let index = NodeIndex::new(parent);
401,382✔
1621
        let child_node = self.graph.add_node(obj);
401,382✔
1622
        self.graph.add_edge(index, child_node, edge);
401,382✔
1623
        Ok(child_node.index())
401,382✔
1624
    }
401,382✔
1625

1626
    /// Add a new parent node to the dag.
1627
    ///
1628
    /// This create a new node on the dag and add an edge to the child from
1629
    /// that new node
1630
    ///
1631
    /// :param int child: The index of the child node
1632
    /// :param obj: The python object to attach to the node
1633
    /// :param edge: The python object to attach to the edge
1634
    ///
1635
    /// :returns index: The index of the newly created parent node
1636
    /// :rtype: int
1637
    #[pyo3(text_signature = "(self, child, obj, edge, /)")]
1638
    pub fn add_parent(&mut self, child: usize, obj: PyObject, edge: PyObject) -> PyResult<usize> {
70✔
1639
        let index = NodeIndex::new(child);
70✔
1640
        let parent_node = self.graph.add_node(obj);
70✔
1641
        self.graph.add_edge(parent_node, index, edge);
70✔
1642
        Ok(parent_node.index())
70✔
1643
    }
70✔
1644

1645
    /// Get the index and data for the neighbors of a node.
1646
    ///
1647
    /// This will return a dictionary where the keys are the node indices of
1648
    /// the adjacent nodes (inbound or outbound) and the value is the edge dat
1649
    /// objects between that adjacent node and the provided node. Note in
1650
    /// the case of a multigraph only one edge will be used, not all of the
1651
    /// edges between two node.
1652
    ///
1653
    /// :param int node: The index of the node to get the neighbors
1654
    ///
1655
    /// :returns: A dictionary where the keys are node indices and the value
1656
    ///     is the edge data object for all nodes that share an edge with the
1657
    ///     specified node.
1658
    /// :rtype: dict
1659
    #[pyo3(text_signature = "(self, node, /)")]
1660
    pub fn adj(&mut self, node: usize) -> DictMap<usize, &PyObject> {
6✔
1661
        let index = NodeIndex::new(node);
6✔
1662
        self.graph
6✔
1663
            .edges_directed(index, petgraph::Direction::Incoming)
6✔
1664
            .map(|edge| (edge.source().index(), edge.weight()))
6✔
1665
            .chain(
6✔
1666
                self.graph
6✔
1667
                    .edges_directed(index, petgraph::Direction::Outgoing)
6✔
1668
                    .map(|edge| (edge.target().index(), edge.weight())),
6✔
1669
            )
6✔
1670
            .collect()
6✔
1671
    }
6✔
1672

1673
    /// Get the index and data for either the parent or children of a node.
1674
    ///
1675
    /// This will return a dictionary where the keys are the node indices of
1676
    /// the adjacent nodes (inbound or outbound as specified) and the value
1677
    /// is the edge data objects for the edges between that adjacent node
1678
    /// and the provided node. Note in the case of a multigraph only one edge
1679
    /// one edge will be used, not all of the edges between two node.
1680
    ///
1681
    /// :param int node: The index of the node to get the neighbors
1682
    /// :param bool direction: The direction to use for finding nodes,
1683
    ///     True means inbound edges and False means outbound edges.
1684
    ///
1685
    /// :returns: A dictionary where the keys are node indices and
1686
    ///     the value is the edge data object for all nodes that share an
1687
    ///     edge with the specified node.
1688
    /// :rtype: dict
1689
    #[pyo3(text_signature = "(self, node, direction, /)")]
1690
    pub fn adj_direction(&mut self, node: usize, direction: bool) -> DictMap<usize, &PyObject> {
8✔
1691
        let index = NodeIndex::new(node);
8✔
1692
        if direction {
8✔
1693
            self.graph
4✔
1694
                .edges_directed(index, petgraph::Direction::Incoming)
4✔
1695
                .map(|edge| (edge.source().index(), edge.weight()))
4✔
1696
                .collect()
4✔
1697
        } else {
1698
            self.graph
4✔
1699
                .edges_directed(index, petgraph::Direction::Outgoing)
4✔
1700
                .map(|edge| (edge.target().index(), edge.weight()))
6✔
1701
                .collect()
4✔
1702
        }
1703
    }
8✔
1704

1705
    /// Get the neighbors (i.e. successors) of a node.
1706
    ///
1707
    /// This will return a list of neighbor node indices. This function
1708
    /// is equivalent to :meth:`successor_indices`.
1709
    ///
1710
    /// :param int node: The index of the node to get the neighbors of
1711
    ///
1712
    /// :returns: A list of the neighbor node indices
1713
    /// :rtype: NodeIndices
1714
    #[pyo3(text_signature = "(self, node, /)")]
1715
    pub fn neighbors(&self, node: usize) -> NodeIndices {
8✔
1716
        NodeIndices {
8✔
1717
            nodes: self
8✔
1718
                .graph
8✔
1719
                .neighbors(NodeIndex::new(node))
8✔
1720
                .map(|node| node.index())
10✔
1721
                .collect::<HashSet<usize>>()
8✔
1722
                .drain()
8✔
1723
                .collect(),
8✔
1724
        }
8✔
1725
    }
8✔
1726

1727
    /// Get the direction-agnostic neighbors (i.e. successors and predecessors) of a node.
1728
    ///
1729
    /// This is functionally equivalent to converting the directed graph to an undirected
1730
    /// graph, and calling ``neighbors`` thereon. For example::
1731
    ///
1732
    ///     import rustworkx
1733
    ///
1734
    ///     dag = rustworkx.generators.directed_cycle_graph(num_nodes=10, bidirectional=False)
1735
    ///
1736
    ///     node = 3
1737
    ///     neighbors = dag.neighbors_undirected(node)
1738
    ///     same_neighbors = dag.to_undirected().neighbors(node)
1739
    ///
1740
    ///     assert sorted(neighbors) == sorted(same_neighbors)
1741
    ///
1742
    /// :param int node: The index of the node to get the neighbors of
1743
    ///
1744
    /// :returns: A list of the neighbor node indices
1745
    /// :rtype: NodeIndices
1746
    #[pyo3(text_signature = "(self, node, /)")]
1747
    pub fn neighbors_undirected(&self, node: usize) -> NodeIndices {
22✔
1748
        NodeIndices {
22✔
1749
            nodes: self
22✔
1750
                .graph
22✔
1751
                .neighbors_undirected(NodeIndex::new(node))
22✔
1752
                .map(|node| node.index())
42✔
1753
                .collect::<HashSet<usize>>()
22✔
1754
                .drain()
22✔
1755
                .collect(),
22✔
1756
        }
22✔
1757
    }
22✔
1758

1759
    /// Get the successor indices of a node.
1760
    ///
1761
    /// This will return a list of the node indices for the successors of
1762
    /// a node
1763
    ///
1764
    /// :param int node: The index of the node to get the successors of
1765
    ///
1766
    /// :returns: A list of the neighbor node indices
1767
    /// :rtype: NodeIndices
1768
    #[pyo3(text_signature = "(self, node, /)")]
1769
    pub fn successor_indices(&self, node: usize) -> NodeIndices {
4✔
1770
        NodeIndices {
4✔
1771
            nodes: self
4✔
1772
                .graph
4✔
1773
                .neighbors_directed(NodeIndex::new(node), petgraph::Direction::Outgoing)
4✔
1774
                .map(|node| node.index())
6✔
1775
                .collect(),
4✔
1776
        }
4✔
1777
    }
4✔
1778

1779
    /// Get the predecessor indices of a node.
1780
    ///
1781
    /// This will return a list of the node indices for the predecessors of
1782
    /// a node
1783
    ///
1784
    /// :param int node: The index of the node to get the predecessors of
1785
    ///
1786
    /// :returns: A list of the neighbor node indices
1787
    /// :rtype: NodeIndices
1788
    #[pyo3(text_signature = "(self, node, /)")]
1789
    pub fn predecessor_indices(&self, node: usize) -> NodeIndices {
96✔
1790
        NodeIndices {
96✔
1791
            nodes: self
96✔
1792
                .graph
96✔
1793
                .neighbors_directed(NodeIndex::new(node), petgraph::Direction::Incoming)
96✔
1794
                .map(|node| node.index())
108✔
1795
                .collect(),
96✔
1796
        }
96✔
1797
    }
96✔
1798

1799
    /// Return the list of edge indices incident to a provided node
1800
    ///
1801
    /// You can later retrieve the data payload of this edge with
1802
    /// :meth:`~rustworkx.PyDiGraph.get_edge_data_by_index` or its
1803
    /// endpoints with :meth:`~rustworkx.PyDiGraph.get_edge_endpoints_by_index`.
1804
    ///
1805
    /// By default this method will only return the outgoing edges of
1806
    /// the provided ``node``. If you would like to access both the
1807
    /// incoming and outgoing edges you can set the ``all_edges``
1808
    /// kwarg to ``True``.
1809
    ///
1810
    /// :param int node: The node index to get incident edges from. If
1811
    ///     this node index is not present in the graph this method will
1812
    ///     return an empty list and not error.
1813
    /// :param bool all_edges: If set to ``True`` both incoming and outgoing
1814
    ///     edges to ``node`` will be returned.
1815
    ///
1816
    /// :returns: A list of the edge indices incident to a node in the graph
1817
    /// :rtype: EdgeIndices
1818
    #[pyo3(text_signature = "(self, node, /, all_edges=False)")]
1819
    #[pyo3(signature=(node, all_edges=false))]
1820
    pub fn incident_edges(&self, node: usize, all_edges: bool) -> EdgeIndices {
6✔
1821
        let node_index = NodeIndex::new(node);
6✔
1822
        if all_edges {
6✔
1823
            EdgeIndices {
2✔
1824
                edges: self
2✔
1825
                    .graph
2✔
1826
                    .edges_directed(node_index, petgraph::Direction::Outgoing)
2✔
1827
                    .chain(
2✔
1828
                        self.graph
2✔
1829
                            .edges_directed(node_index, petgraph::Direction::Incoming),
2✔
1830
                    )
2✔
1831
                    .map(|e| e.id().index())
4✔
1832
                    .collect(),
2✔
1833
            }
2✔
1834
        } else {
1835
            EdgeIndices {
4✔
1836
                edges: self
4✔
1837
                    .graph
4✔
1838
                    .edges(node_index)
4✔
1839
                    .map(|e| e.id().index())
4✔
1840
                    .collect(),
4✔
1841
            }
4✔
1842
        }
1843
    }
6✔
1844

1845
    /// Return the index map of edges incident to a provided node
1846
    ///
1847
    /// By default this method will only return the outgoing edges of
1848
    /// the provided ``node``. If you would like to access both the
1849
    /// incoming and outgoing edges you can set the ``all_edges``
1850
    /// kwarg to ``True``.
1851
    ///
1852
    /// :param int node: The node index to get incident edges from. If
1853
    ///     this node index is not present in the graph this method will
1854
    ///     return an empty list and not error.
1855
    /// :param bool all_edges: If set to ``True`` both incoming and outgoing
1856
    ///     edges to ``node`` will be returned.
1857
    ///
1858
    /// :returns: A mapping of incident edge indices to the tuple
1859
    ///     ``(source, target, data)``
1860
    /// :rtype: EdgeIndexMap
1861
    #[pyo3(text_signature = "(self, node, /, all_edges=False)")]
1862
    #[pyo3(signature=(node, all_edges=false))]
1863
    pub fn incident_edge_index_map(
6✔
1864
        &self,
6✔
1865
        py: Python,
6✔
1866
        node: usize,
6✔
1867
        all_edges: bool,
6✔
1868
    ) -> EdgeIndexMap {
6✔
1869
        let node_index = NodeIndex::new(node);
6✔
1870
        if all_edges {
6✔
1871
            EdgeIndexMap {
2✔
1872
                edge_map: self
2✔
1873
                    .graph
2✔
1874
                    .edges_directed(node_index, petgraph::Direction::Outgoing)
2✔
1875
                    .chain(
2✔
1876
                        self.graph
2✔
1877
                            .edges_directed(node_index, petgraph::Direction::Incoming),
2✔
1878
                    )
2✔
1879
                    .map(|edge| {
4✔
1880
                        (
4✔
1881
                            edge.id().index(),
4✔
1882
                            (
4✔
1883
                                edge.source().index(),
4✔
1884
                                edge.target().index(),
4✔
1885
                                edge.weight().clone_ref(py),
4✔
1886
                            ),
4✔
1887
                        )
4✔
1888
                    })
4✔
1889
                    .collect(),
2✔
1890
            }
2✔
1891
        } else {
1892
            EdgeIndexMap {
4✔
1893
                edge_map: self
4✔
1894
                    .graph
4✔
1895
                    .edges(node_index)
4✔
1896
                    .map(|edge| {
4✔
1897
                        (
2✔
1898
                            edge.id().index(),
2✔
1899
                            (
2✔
1900
                                edge.source().index(),
2✔
1901
                                edge.target().index(),
2✔
1902
                                edge.weight().clone_ref(py),
2✔
1903
                            ),
2✔
1904
                        )
2✔
1905
                    })
4✔
1906
                    .collect(),
4✔
1907
            }
4✔
1908
        }
1909
    }
6✔
1910

1911
    /// Get the index and edge data for all parents of a node.
1912
    ///
1913
    /// This will return a list of tuples with the parent index the node index
1914
    /// and the edge data. This can be used to recreate add_edge() calls.
1915
    /// :param int node: The index of the node to get the edges for
1916
    ///
1917
    /// :param int node: The index of the node to get the edges for
1918
    ///
1919
    /// :returns: A list of tuples of the form:
1920
    ///     ``(parent_index, node_index, edge_data)```
1921
    /// :rtype: WeightedEdgeList
1922
    #[pyo3(text_signature = "(self, node, /)")]
1923
    pub fn in_edges(&self, py: Python, node: usize) -> WeightedEdgeList {
286✔
1924
        let index = NodeIndex::new(node);
286✔
1925
        let dir = petgraph::Direction::Incoming;
286✔
1926
        let raw_edges = self.graph.edges_directed(index, dir);
286✔
1927
        let out_list: Vec<(usize, usize, PyObject)> = raw_edges
286✔
1928
            .map(|x| (x.source().index(), node, x.weight().clone_ref(py)))
548✔
1929
            .collect();
286✔
1930
        WeightedEdgeList { edges: out_list }
286✔
1931
    }
286✔
1932

1933
    /// Get the index and edge data for all children of a node.
1934
    ///
1935
    /// This will return a list of tuples with the child index the node index
1936
    /// and the edge data. This can be used to recreate add_edge() calls.
1937
    ///
1938
    /// :param int node: The index of the node to get the edges for
1939
    ///
1940
    /// :returns out_edges: A list of tuples of the form:
1941
    ///     ```(node_index, child_index, edge_data)```
1942
    /// :rtype: WeightedEdgeList
1943
    #[pyo3(text_signature = "(self, node, /)")]
1944
    pub fn out_edges(&self, py: Python, node: usize) -> WeightedEdgeList {
528✔
1945
        let index = NodeIndex::new(node);
528✔
1946
        let dir = petgraph::Direction::Outgoing;
528✔
1947
        let raw_edges = self.graph.edges_directed(index, dir);
528✔
1948
        let out_list: Vec<(usize, usize, PyObject)> = raw_edges
528✔
1949
            .map(|x| (node, x.target().index(), x.weight().clone_ref(py)))
2,228✔
1950
            .collect();
528✔
1951
        WeightedEdgeList { edges: out_list }
528✔
1952
    }
528✔
1953

1954
    /// Add new nodes to the graph.
1955
    ///
1956
    /// :param iterable obj_list: An iterable of python objects to attach to the graph
1957
    ///     as new nodes
1958
    ///
1959
    /// :returns: A list of int indices of the newly created nodes
1960
    /// :rtype: NodeIndices
1961
    #[pyo3(text_signature = "(self, obj_list, /)")]
1962
    pub fn add_nodes_from(&mut self, obj_list: Bound<'_, PyAny>) -> PyResult<NodeIndices> {
404✔
1963
        let mut out_list = Vec::new();
404✔
1964
        for py_obj in obj_list.try_iter()? {
2,014,968✔
1965
            let obj = py_obj?.extract::<PyObject>()?;
2,014,968✔
1966
            out_list.push(self.graph.add_node(obj).index());
2,014,968✔
1967
        }
1968
        Ok(NodeIndices { nodes: out_list })
404✔
1969
    }
404✔
1970

1971
    /// Remove nodes from the graph.
1972
    ///
1973
    /// If a node index in the list is not present in the graph it will be
1974
    /// ignored.
1975
    ///
1976
    /// :param iterable index_list: An iterable of node indices to remove from the
1977
    ///     graph.
1978
    #[pyo3(text_signature = "(self, index_list, /)")]
1979
    pub fn remove_nodes_from(&mut self, index_list: Bound<'_, PyAny>) -> PyResult<()> {
30✔
1980
        for py_obj in index_list.try_iter()? {
58✔
1981
            let node = py_obj?.extract::<usize>()?;
58✔
1982
            self.remove_node(node)?;
58✔
1983
        }
1984
        Ok(())
30✔
1985
    }
30✔
1986

1987
    /// Get the degree of a node for inbound edges.
1988
    ///
1989
    /// :param int node: The index of the node to find the inbound degree of
1990
    ///
1991
    /// :returns: The inbound degree for the specified node
1992
    /// :rtype: int
1993
    #[pyo3(text_signature = "(self, node, /)")]
1994
    pub fn in_degree(&self, node: usize) -> usize {
96✔
1995
        let index = NodeIndex::new(node);
96✔
1996
        let dir = petgraph::Direction::Incoming;
96✔
1997
        let neighbors = self.graph.edges_directed(index, dir);
96✔
1998
        neighbors.count()
96✔
1999
    }
96✔
2000

2001
    /// Get the degree of a node for outbound edges.
2002
    ///
2003
    /// :param int node: The index of the node to find the outbound degree of
2004
    /// :returns: The outbound degree for the specified node
2005
    /// :rtype: int
2006
    #[pyo3(text_signature = "(self, node, /)")]
2007
    pub fn out_degree(&self, node: usize) -> usize {
64✔
2008
        let index = NodeIndex::new(node);
64✔
2009
        let dir = petgraph::Direction::Outgoing;
64✔
2010
        let neighbors = self.graph.edges_directed(index, dir);
64✔
2011
        neighbors.count()
64✔
2012
    }
64✔
2013

2014
    /// Find a target node with a specific edge
2015
    ///
2016
    /// This method is used to find a target node that is a adjacent to a given
2017
    /// node given an edge condition.
2018
    ///
2019
    /// :param int node: The node to use as the source of the search
2020
    /// :param callable predicate: A python callable that will take a single
2021
    ///     parameter, the edge object, and will return a boolean if the
2022
    ///     edge matches or not
2023
    ///
2024
    /// :returns: The node object that has an edge to it from the provided
2025
    ///     node index which matches the provided condition
2026
    #[pyo3(text_signature = "(self, node, predicate, /)")]
2027
    pub fn find_adjacent_node_by_edge(
4✔
2028
        &self,
4✔
2029
        py: Python,
4✔
2030
        node: usize,
4✔
2031
        predicate: PyObject,
4✔
2032
    ) -> PyResult<&PyObject> {
4✔
2033
        let predicate_callable = |a: &PyObject| -> PyResult<PyObject> {
6✔
2034
            let res = predicate.call1(py, (a,))?;
6✔
2035
            res.into_py_any(py)
6✔
2036
        };
6✔
2037
        let index = NodeIndex::new(node);
4✔
2038
        let dir = petgraph::Direction::Outgoing;
4✔
2039
        let edges = self.graph.edges_directed(index, dir);
4✔
2040
        for edge in edges {
8✔
2041
            let edge_predicate_raw = predicate_callable(edge.weight())?;
6✔
2042
            let edge_predicate: bool = edge_predicate_raw.extract(py)?;
6✔
2043
            if edge_predicate {
6✔
2044
                return Ok(self.graph.node_weight(edge.target()).unwrap());
2✔
2045
            }
4✔
2046
        }
2047
        Err(NoSuitableNeighbors::new_err("No suitable neighbor"))
2✔
2048
    }
4✔
2049

2050
    /// Find a source node with a specific edge
2051
    ///
2052
    /// This method is used to find a predecessor of
2053
    /// a given node given an edge condition.
2054
    ///
2055
    /// :param int node: The node to use as the source of the search
2056
    /// :param callable predicate: A python callable that will take a single
2057
    ///     parameter, the edge object, and will return a boolean if the
2058
    ///     edge matches or not
2059
    ///
2060
    /// :returns: The node object that has an edge from it to the provided
2061
    ///     node index which matches the provided condition
2062
    #[pyo3(text_signature = "(self, node, predicate, /)")]
2063
    pub fn find_predecessor_node_by_edge(
4✔
2064
        &self,
4✔
2065
        py: Python,
4✔
2066
        node: usize,
4✔
2067
        predicate: PyObject,
4✔
2068
    ) -> PyResult<&PyObject> {
4✔
2069
        let predicate_callable = |a: &PyObject| -> PyResult<PyObject> {
4✔
2070
            let res = predicate.call1(py, (a,))?;
4✔
2071
            res.into_py_any(py)
4✔
2072
        };
4✔
2073
        let index = NodeIndex::new(node);
4✔
2074
        let dir = petgraph::Direction::Incoming;
4✔
2075
        let edges = self.graph.edges_directed(index, dir);
4✔
2076
        for edge in edges {
6✔
2077
            let edge_predicate_raw = predicate_callable(edge.weight())?;
4✔
2078
            let edge_predicate: bool = edge_predicate_raw.extract(py)?;
4✔
2079
            if edge_predicate {
4✔
2080
                return Ok(self.graph.node_weight(edge.source()).unwrap());
2✔
2081
            }
2✔
2082
        }
2083
        Err(NoSuitableNeighbors::new_err("No suitable neighbor"))
2✔
2084
    }
4✔
2085

2086
    /// Generate a dot file from the graph
2087
    ///
2088
    /// :param node_attr: A callable that will take in a node data object
2089
    ///     and return a dictionary of attributes to be associated with the
2090
    ///     node in the dot file. The key and value of this dictionary **must**
2091
    ///     be strings. If they're not strings rustworkx will raise TypeError
2092
    ///     (unfortunately without an error message because of current
2093
    ///     limitations in the PyO3 type checking)
2094
    /// :param edge_attr: A callable that will take in an edge data object
2095
    ///     and return a dictionary of attributes to be associated with the
2096
    ///     node in the dot file. The key and value of this dictionary **must**
2097
    ///     be a string. If they're not strings rustworkx will raise TypeError
2098
    ///     (unfortunately without an error message because of current
2099
    ///     limitations in the PyO3 type checking)
2100
    /// :param dict graph_attr: An optional dictionary that specifies any graph
2101
    ///     attributes for the output dot file. The key and value of this
2102
    ///     dictionary **must** be a string. If they're not strings rustworkx
2103
    ///     will raise TypeError (unfortunately without an error message
2104
    ///     because of current limitations in the PyO3 type checking)
2105
    /// :param str filename: An optional path to write the dot file to
2106
    ///     if specified there is no return from the function
2107
    ///
2108
    /// :returns: A string with the dot file contents if filename is not
2109
    ///     specified.
2110
    /// :rtype: str
2111
    ///
2112
    /// Using this method enables you to leverage graphviz to visualize a
2113
    /// :class:`rustworkx.PyDiGraph` object. For example:
2114
    ///
2115
    /// .. jupyter-execute::
2116
    ///
2117
    ///   import os
2118
    ///   import tempfile
2119
    ///
2120
    ///   import pydot
2121
    ///   from PIL import Image
2122
    ///
2123
    ///   import rustworkx as rx
2124
    ///
2125
    ///   graph = rx.directed_gnp_random_graph(15, .25)
2126
    ///   dot_str = graph.to_dot(
2127
    ///       lambda node: dict(
2128
    ///           color='black', fillcolor='lightblue', style='filled'))
2129
    ///   dot = pydot.graph_from_dot_data(dot_str)[0]
2130
    ///
2131
    ///   with tempfile.TemporaryDirectory() as tmpdirname:
2132
    ///       tmp_path = os.path.join(tmpdirname, 'dag.png')
2133
    ///       dot.write_png(tmp_path)
2134
    ///       image = Image.open(tmp_path)
2135
    ///       os.remove(tmp_path)
2136
    ///   image
2137
    ///
2138
    #[pyo3(
2139
        text_signature = "(self, /, node_attr=None, edge_attr=None, graph_attr=None, filename=None)",
2140
        signature = (node_attr=None, edge_attr=None, graph_attr=None, filename=None)
2141
    )]
2142
    pub fn to_dot<'py>(
10✔
2143
        &self,
10✔
2144
        py: Python<'py>,
10✔
2145
        node_attr: Option<PyObject>,
10✔
2146
        edge_attr: Option<PyObject>,
10✔
2147
        graph_attr: Option<BTreeMap<String, String>>,
10✔
2148
        filename: Option<String>,
10✔
2149
    ) -> PyResult<Option<Bound<'py, PyString>>> {
10✔
2150
        match filename {
10✔
2151
            Some(filename) => {
2✔
2152
                let mut file = File::create(filename)?;
2✔
2153
                build_dot(py, &self.graph, &mut file, graph_attr, node_attr, edge_attr)?;
2✔
2154
                Ok(None)
2✔
2155
            }
2156
            None => {
2157
                let mut file = Vec::<u8>::new();
8✔
2158
                build_dot(py, &self.graph, &mut file, graph_attr, node_attr, edge_attr)?;
8✔
2159
                Ok(Some(PyString::new(py, str::from_utf8(&file)?)))
8✔
2160
            }
2161
        }
2162
    }
10✔
2163

2164
    /// Read an edge list file and create a new PyDiGraph object from the
2165
    /// contents
2166
    ///
2167
    /// The expected format for the edge list file is a line separated list
2168
    /// of delimited node ids. If there are more than 3 elements on
2169
    /// a line the 3rd on will be treated as a string weight for the edge
2170
    ///
2171
    /// :param str path: The path of the file to open
2172
    /// :param str comment: Optional character to use as a comment by default
2173
    ///     there are no comment characters
2174
    /// :param str deliminator: Optional character to use as a deliminator by
2175
    ///     default any whitespace will be used
2176
    /// :param bool labels: If set to ``True`` the first two separated fields
2177
    ///     will be treated as string labels uniquely identifying a node
2178
    ///     instead of node indices.
2179
    ///
2180
    /// For example:
2181
    ///
2182
    /// .. jupyter-execute::
2183
    ///
2184
    ///   import tempfile
2185
    ///
2186
    ///   import rustworkx as rx
2187
    ///   from rustworkx.visualization import mpl_draw
2188
    ///
2189
    ///   with tempfile.NamedTemporaryFile('wt') as fd:
2190
    ///       path = fd.name
2191
    ///       fd.write('0 1\n')
2192
    ///       fd.write('0 2\n')
2193
    ///       fd.write('0 3\n')
2194
    ///       fd.write('1 2\n')
2195
    ///       fd.write('2 3\n')
2196
    ///       fd.flush()
2197
    ///       graph = rx.PyDiGraph.read_edge_list(path)
2198
    ///   mpl_draw(graph)
2199
    ///
2200
    #[staticmethod]
2201
    #[pyo3(signature=(path, comment=None, deliminator=None, labels=false))]
2202
    #[pyo3(text_signature = "(path, /, comment=None, deliminator=None, labels=False)")]
2203
    pub fn read_edge_list(
22✔
2204
        py: Python,
22✔
2205
        path: &str,
22✔
2206
        comment: Option<String>,
22✔
2207
        deliminator: Option<String>,
22✔
2208
        labels: bool,
22✔
2209
    ) -> PyResult<PyDiGraph> {
22✔
2210
        let file = File::open(path)?;
22✔
2211
        let buf_reader = BufReader::new(file);
20✔
2212
        let mut out_graph = StablePyGraph::<Directed>::new();
20✔
2213
        let mut label_map: HashMap<String, usize> = HashMap::new();
20✔
2214
        for line_raw in buf_reader.lines() {
54✔
2215
            let line = line_raw?;
54✔
2216
            let skip = match &comment {
54✔
2217
                Some(comm) => line.trim().starts_with(comm),
36✔
2218
                None => line.trim().is_empty(),
18✔
2219
            };
2220
            if skip {
54✔
2221
                continue;
12✔
2222
            }
42✔
2223
            let line_no_comments = match &comment {
42✔
2224
                Some(comm) => line
26✔
2225
                    .find(comm)
26✔
2226
                    .map(|idx| &line[..idx])
26✔
2227
                    .unwrap_or(&line)
26✔
2228
                    .trim()
26✔
2229
                    .to_string(),
26✔
2230
                None => line,
16✔
2231
            };
2232
            let pieces: Vec<&str> = match &deliminator {
42✔
2233
                Some(del) => line_no_comments.split(del).collect(),
14✔
2234
                None => line_no_comments.split_whitespace().collect(),
28✔
2235
            };
2236
            let src: usize;
2237
            let target: usize;
2238
            if labels {
42✔
2239
                let src_str = pieces[0];
10✔
2240
                let target_str = pieces[1];
10✔
2241
                src = match label_map.get(src_str) {
10✔
2242
                    Some(index) => *index,
6✔
2243
                    None => {
2244
                        let index = out_graph.add_node(src_str.into_py_any(py)?).index();
4✔
2245
                        label_map.insert(src_str.to_string(), index);
4✔
2246
                        index
4✔
2247
                    }
2248
                };
2249
                target = match label_map.get(target_str) {
10✔
2250
                    Some(index) => *index,
2✔
2251
                    None => {
2252
                        let index = out_graph.add_node(target_str.into_py_any(py)?).index();
8✔
2253
                        label_map.insert(target_str.to_string(), index);
8✔
2254
                        index
8✔
2255
                    }
2256
                };
2257
            } else {
2258
                src = pieces[0].parse::<usize>()?;
32✔
2259
                target = pieces[1].parse::<usize>()?;
32✔
2260
                let max_index = cmp::max(src, target);
32✔
2261
                // Add nodes to graph
2262
                while max_index >= out_graph.node_count() {
78✔
2263
                    out_graph.add_node(py.None());
46✔
2264
                }
46✔
2265
            }
2266
            // Add edges to graph
2267
            let weight = if pieces.len() > 2 {
42✔
2268
                let weight_str = match &deliminator {
24✔
2269
                    Some(del) => pieces[2..].join(del),
12✔
2270
                    None => pieces[2..].join(&' '.to_string()),
12✔
2271
                };
2272
                PyString::new(py, &weight_str).into()
24✔
2273
            } else {
2274
                py.None()
18✔
2275
            };
2276
            out_graph.add_edge(NodeIndex::new(src), NodeIndex::new(target), weight);
42✔
2277
        }
2278
        Ok(PyDiGraph {
20✔
2279
            graph: out_graph,
20✔
2280
            cycle_state: algo::DfsSpace::default(),
20✔
2281
            check_cycle: false,
20✔
2282
            node_removed: false,
20✔
2283
            multigraph: true,
20✔
2284
            attrs: py.None(),
20✔
2285
        })
20✔
2286
    }
22✔
2287

2288
    /// Write an edge list file from the PyDiGraph object
2289
    ///
2290
    /// :param str path: The path to write the output file to
2291
    /// :param str deliminator: The optional character to use as a deliminator
2292
    ///     if not specified ``" "`` is used.
2293
    /// :param callable weight_fn: An optional callback function that will be
2294
    ///     passed an edge's data payload/weight object and is expected to
2295
    ///     return a string (a ``TypeError`` will be raised if it doesn't
2296
    ///     return a string). If specified the weight in the output file
2297
    ///     for each edge will be set to the returned string.
2298
    ///
2299
    ///  For example:
2300
    ///
2301
    ///  .. jupyter-execute::
2302
    ///
2303
    ///     import os
2304
    ///     import tempfile
2305
    ///
2306
    ///     import rustworkx as rx
2307
    ///
2308
    ///     graph = rx.generators.directed_path_graph(5)
2309
    ///     path = os.path.join(tempfile.gettempdir(), "edge_list")
2310
    ///     graph.write_edge_list(path, deliminator=',')
2311
    ///     # Print file contents
2312
    ///     with open(path, 'rt') as edge_file:
2313
    ///         print(edge_file.read())
2314
    ///
2315
    #[pyo3(text_signature = "(self, path, /, deliminator=None, weight_fn=None)", signature = (path, deliminator=None, weight_fn=None))]
2316
    pub fn write_edge_list(
10✔
2317
        &self,
10✔
2318
        py: Python,
10✔
2319
        path: &str,
10✔
2320
        deliminator: Option<char>,
10✔
2321
        weight_fn: Option<PyObject>,
10✔
2322
    ) -> PyResult<()> {
10✔
2323
        let file = File::create(path)?;
10✔
2324
        let mut buf_writer = BufWriter::new(file);
10✔
2325
        let delim = match deliminator {
10✔
2326
            Some(delim) => delim.to_string(),
2✔
2327
            None => " ".to_string(),
8✔
2328
        };
2329

2330
        for edge in self.graph.edge_references() {
20✔
2331
            buf_writer.write_all(
20✔
2332
                format!(
20✔
2333
                    "{}{}{}",
20✔
2334
                    edge.source().index(),
20✔
2335
                    delim,
20✔
2336
                    edge.target().index()
20✔
2337
                )
20✔
2338
                .as_bytes(),
20✔
2339
            )?;
20✔
2340
            match weight_callable(py, &weight_fn, edge.weight(), None as Option<String>)? {
20✔
2341
                Some(weight) => buf_writer.write_all(format!("{}{}\n", delim, weight).as_bytes()),
8✔
2342
                None => buf_writer.write_all(b"\n"),
8✔
2343
            }?;
×
2344
        }
2345
        buf_writer.flush()?;
6✔
2346
        Ok(())
6✔
2347
    }
10✔
2348

2349
    /// Create a new :class:`~rustworkx.PyDiGraph` object from an adjacency matrix
2350
    /// with matrix elements of type ``float``
2351
    ///
2352
    /// This method can be used to construct a new :class:`~rustworkx.PyDiGraph`
2353
    /// object from an input adjacency matrix. The node weights will be the
2354
    /// index from the matrix. The edge weights will be a float value of the
2355
    /// value from the matrix.
2356
    ///
2357
    /// This differs from the
2358
    /// :meth:`~rustworkx.PyDiGraph.from_complex_adjacency_matrix` in that the
2359
    /// type of the elements of input matrix must be a ``float`` (specifically
2360
    /// a ``numpy.float64``) and the output graph edge weights will be ``float``
2361
    /// too. While in :meth:`~rustworkx.PyDiGraph.from_complex_adjacency_matrix`
2362
    /// the matrix elements are of type ``complex`` (specifically
2363
    /// ``numpy.complex128``) and the edge weights in the output graph will be
2364
    /// ``complex`` too.
2365
    ///
2366
    /// :param ndarray matrix: The input numpy array adjacency matrix to create
2367
    ///     a new :class:`~rustworkx.PyDiGraph` object from. It must be a 2
2368
    ///     dimensional array and be a ``float``/``np.float64`` data type.
2369
    /// :param float null_value: An optional float that will treated as a null
2370
    ///     value. If any element in the input matrix is this value it will be
2371
    ///     treated as not an edge. By default this is ``0.0``
2372
    ///
2373
    /// :returns: A new graph object generated from the adjacency matrix
2374
    /// :rtype: PyDiGraph
2375
    #[staticmethod]
2376
    #[pyo3(signature=(matrix, null_value=0.0), text_signature = "(matrix, /, null_value=0.0)")]
2377
    pub fn from_adjacency_matrix<'p>(
14✔
2378
        py: Python<'p>,
14✔
2379
        matrix: PyReadonlyArray2<'p, f64>,
14✔
2380
        null_value: f64,
14✔
2381
    ) -> PyResult<PyDiGraph> {
14✔
2382
        _from_adjacency_matrix(py, matrix, null_value)
14✔
2383
    }
14✔
2384

2385
    /// Create a new :class:`~rustworkx.PyDiGraph` object from an adjacency matrix
2386
    /// with matrix elements of type ``complex``
2387
    ///
2388
    /// This method can be used to construct a new :class:`~rustworkx.PyDiGraph`
2389
    /// object from an input adjacency matrix. The node weights will be the
2390
    /// index from the matrix. The edge weights will be a complex value of the
2391
    /// value from the matrix.
2392
    ///
2393
    /// This differs from the
2394
    /// :meth:`~rustworkx.PyDiGraph.from_adjacency_matrix` in that the type of
2395
    /// the elements of the input matrix in this method must be a ``complex``
2396
    /// (specifically a ``numpy.complex128``) and the output graph edge weights
2397
    /// will be ``complex`` too. While in
2398
    /// :meth:`~rustworkx.PyDiGraph.from_adjacency_matrix` the matrix elements
2399
    /// are of type ``float`` (specifically ``numpy.float64``) and the edge
2400
    /// weights in the output graph will be ``float`` too.
2401
    ///
2402
    /// :param ndarray matrix: The input numpy array adjacency matrix to create
2403
    ///     a new :class:`~rustworkx.PyDiGraph` object from. It must be a 2
2404
    ///     dimensional array and be a ``complex``/``np.complex128`` data type.
2405
    /// :param complex null_value: An optional complex that will treated as a
2406
    ///     null value. If any element in the input matrix is this value it
2407
    ///     will be treated as not an edge. By default this is ``0.0+0.0j``
2408
    ///
2409
    /// :returns: A new graph object generated from the adjacency matrix
2410
    /// :rtype: PyDiGraph
2411
    #[staticmethod]
2412
    #[pyo3(signature=(matrix, null_value=Complex64::zero()), text_signature = "(matrix, /, null_value=0.0+0.0j)")]
2413
    pub fn from_complex_adjacency_matrix<'p>(
12✔
2414
        py: Python<'p>,
12✔
2415
        matrix: PyReadonlyArray2<'p, Complex64>,
12✔
2416
        null_value: Complex64,
12✔
2417
    ) -> PyResult<PyDiGraph> {
12✔
2418
        _from_adjacency_matrix(py, matrix, null_value)
12✔
2419
    }
12✔
2420

2421
    /// Add another PyDiGraph object into this PyDiGraph
2422
    ///
2423
    /// :param PyDiGraph other: The other PyDiGraph object to add onto this
2424
    ///     graph.
2425
    /// :param dict node_map: A dictionary mapping node indices from this
2426
    ///     PyDiGraph object to node indices in the other PyDiGraph object.
2427
    ///     The keys are a node index in this graph and the value is a tuple
2428
    ///     of the node index in the other graph to add an edge to and the
2429
    ///     weight of that edge. For example::
2430
    ///
2431
    ///         {
2432
    ///             1: (2, "weight"),
2433
    ///             2: (4, "weight2")
2434
    ///         }
2435
    ///
2436
    /// :param node_map_func: An optional python callable that will take in a
2437
    ///     single node weight/data object and return a new node weight/data
2438
    ///     object that will be used when adding an node from other onto this
2439
    ///     graph.
2440
    /// :param edge_map_func: An optional python callable that will take in a
2441
    ///     single edge weight/data object and return a new edge weight/data
2442
    ///     object that will be used when adding an edge from other onto this
2443
    ///     graph.
2444
    ///
2445
    /// :returns: new_node_ids: A dictionary mapping node index from the other
2446
    ///     PyDiGraph to the corresponding node index in this PyDAG after they've been
2447
    ///     combined
2448
    /// :rtype: dict
2449
    ///
2450
    /// For example, start by building a graph:
2451
    ///
2452
    /// .. jupyter-execute::
2453
    ///
2454
    ///   import rustworkx as rx
2455
    ///   from rustworkx.visualization import mpl_draw
2456
    ///
2457
    ///   # Build first graph and visualize:
2458
    ///   graph = rx.PyDiGraph()
2459
    ///   node_a = graph.add_node('A')
2460
    ///   node_b = graph.add_child(node_a, 'B', 'A to B')
2461
    ///   node_c = graph.add_child(node_b, 'C', 'B to C')
2462
    ///   mpl_draw(graph, with_labels=True, labels=str, edge_labels=str)
2463
    ///
2464
    /// Then build a second one:
2465
    ///
2466
    /// .. jupyter-execute::
2467
    ///
2468
    ///   # Build second graph and visualize:
2469
    ///   other_graph = rx.PyDiGraph()
2470
    ///   node_d = other_graph.add_node('D')
2471
    ///   other_graph.add_child(node_d, 'E', 'D to E')
2472
    ///   mpl_draw(other_graph, with_labels=True, labels=str, edge_labels=str)
2473
    ///
2474
    /// Finally compose the ``other_graph`` onto ``graph``
2475
    ///
2476
    /// .. jupyter-execute::
2477
    ///
2478
    ///   node_map = {node_b: (node_d, 'B to D')}
2479
    ///   graph.compose(other_graph, node_map)
2480
    ///   mpl_draw(graph, with_labels=True, labels=str, edge_labels=str)
2481
    ///
2482
    #[pyo3(text_signature = "(self, other, node_map, /, node_map_func=None, edge_map_func=None)", signature = (other, node_map, node_map_func=None, edge_map_func=None))]
2483
    pub fn compose(
6✔
2484
        &mut self,
6✔
2485
        py: Python,
6✔
2486
        other: &PyDiGraph,
6✔
2487
        node_map: HashMap<usize, (usize, PyObject)>,
6✔
2488
        node_map_func: Option<PyObject>,
6✔
2489
        edge_map_func: Option<PyObject>,
6✔
2490
    ) -> PyResult<PyObject> {
6✔
2491
        let mut new_node_map: DictMap<NodeIndex, NodeIndex> =
6✔
2492
            DictMap::with_capacity(other.node_count());
6✔
2493

2494
        // TODO: Reimplement this without looping over the graphs
2495
        // Loop over other nodes add add to self graph
2496
        for node in other.graph.node_indices() {
28✔
2497
            let new_index = self.graph.add_node(weight_transform_callable(
28✔
2498
                py,
28✔
2499
                &node_map_func,
28✔
2500
                &other.graph[node],
28✔
2501
            )?);
28✔
2502
            new_node_map.insert(node, new_index);
28✔
2503
        }
2504

2505
        // loop over other edges and add to self graph
2506
        for edge in other.graph.edge_references() {
30✔
2507
            let new_p_index = new_node_map.get(&edge.source()).unwrap();
30✔
2508
            let new_c_index = new_node_map.get(&edge.target()).unwrap();
30✔
2509
            let weight = weight_transform_callable(py, &edge_map_func, edge.weight())?;
30✔
2510
            self._add_edge(*new_p_index, *new_c_index, weight)?;
30✔
2511
        }
2512
        // Add edges from map
2513
        for (this_index, (index, weight)) in node_map.iter() {
6✔
2514
            let new_index = new_node_map.get(&NodeIndex::new(*index)).unwrap();
6✔
2515
            self._add_edge(
6✔
2516
                NodeIndex::new(*this_index),
6✔
2517
                *new_index,
6✔
2518
                weight.clone_ref(py),
6✔
2519
            )?;
6✔
2520
        }
2521
        let out_dict = PyDict::new(py);
6✔
2522
        for (orig_node, new_node) in new_node_map.iter() {
28✔
2523
            out_dict.set_item(orig_node.index(), new_node.index())?;
28✔
2524
        }
2525
        Ok(out_dict.into())
6✔
2526
    }
6✔
2527

2528
    /// Substitute a node with a PyDigraph object
2529
    ///
2530
    /// :param int node: The node to replace with the PyDiGraph object
2531
    /// :param PyDiGraph other: The other graph to replace ``node`` with
2532
    /// :param callable edge_map_fn: A callable object that will take 3 position
2533
    ///     parameters, ``(source, target, weight)`` to represent an edge either to
2534
    ///     or from ``node`` in this graph. The expected return value from this
2535
    ///     callable is the node index of the node in ``other`` that an edge should
2536
    ///     be to/from. If None is returned, that edge will be skipped and not
2537
    ///     be copied.
2538
    /// :param callable node_filter: An optional callable object that when used
2539
    ///     will receive a node's payload object from ``other`` and return
2540
    ///     ``True`` if that node is to be included in the graph or not.
2541
    /// :param callable edge_weight_map: An optional callable object that when
2542
    ///     used will receive an edge's weight/data payload from ``other`` and
2543
    ///     will return an object to use as the weight for a newly created edge
2544
    ///     after the edge is mapped from ``other``. If not specified the weight
2545
    ///     from the edge in ``other`` will be copied by reference and used.
2546
    ///
2547
    /// :returns: A mapping of node indices in ``other`` to the equivalent node
2548
    ///     in this graph.
2549
    /// :rtype: NodeMap
2550
    ///
2551
    /// .. note::
2552
    ///
2553
    ///    The return type is a :class:`rustworkx.NodeMap` which is an unordered
2554
    ///    type. So it does not provide a deterministic ordering between objects
2555
    ///    when iterated over (although the same object will have a consistent
2556
    ///    order when iterated over multiple times).
2557
    ///
2558
    #[pyo3(
2559
        text_signature = "(self, node, other, edge_map_fn, /, node_filter=None, edge_weight_map=None)",
2560
        signature = (node, other, edge_map_fn, node_filter=None, edge_weight_map=None)
2561
    )]
2562
    fn substitute_node_with_subgraph(
70✔
2563
        &mut self,
70✔
2564
        py: Python,
70✔
2565
        node: usize,
70✔
2566
        other: &PyDiGraph,
70✔
2567
        edge_map_fn: PyObject,
70✔
2568
        node_filter: Option<PyObject>,
70✔
2569
        edge_weight_map: Option<PyObject>,
70✔
2570
    ) -> PyResult<NodeMap> {
70✔
2571
        let weight_map_fn = |obj: &PyObject, weight_fn: &Option<PyObject>| -> PyResult<PyObject> {
356✔
2572
            match weight_fn {
356✔
2573
                Some(weight_fn) => weight_fn.call1(py, (obj,)),
4✔
2574
                None => Ok(obj.clone_ref(py)),
352✔
2575
            }
2576
        };
356✔
2577
        let map_fn = |source: usize, target: usize, weight: &PyObject| -> PyResult<Option<usize>> {
70✔
2578
            let res = edge_map_fn.call1(py, (source, target, weight))?;
46✔
2579
            res.extract(py)
46✔
2580
        };
46✔
2581
        let filter_fn = |obj: &PyObject, filter_fn: &Option<PyObject>| -> PyResult<bool> {
290✔
2582
            match filter_fn {
290✔
2583
                Some(filter) => {
10✔
2584
                    let res = filter.call1(py, (obj,))?;
10✔
2585
                    res.extract(py)
10✔
2586
                }
2587
                None => Ok(true),
280✔
2588
            }
2589
        };
290✔
2590
        let node_index: NodeIndex = NodeIndex::new(node);
70✔
2591
        if self.graph.node_weight(node_index).is_none() {
70✔
2592
            return Err(PyIndexError::new_err(format!(
2✔
2593
                "Specified node {} is not in this graph",
2✔
2594
                node
2✔
2595
            )));
2✔
2596
        }
68✔
2597
        // Copy nodes from other to self
68✔
2598
        let mut out_map: DictMap<usize, usize> = DictMap::with_capacity(other.node_count());
68✔
2599
        for node in other.graph.node_indices() {
290✔
2600
            let node_weight = other.graph[node].clone_ref(py);
290✔
2601
            if !filter_fn(&node_weight, &node_filter)? {
290✔
2602
                continue;
4✔
2603
            }
286✔
2604
            let new_index = self.graph.add_node(node_weight);
286✔
2605
            out_map.insert(node.index(), new_index.index());
286✔
2606
        }
2607
        // If no nodes are copied bail here since there is nothing left
2608
        // to do.
2609
        if out_map.is_empty() {
68✔
2610
            self.remove_node(node_index.index())?;
2✔
2611
            // Return a new empty map to clear allocation from out_map
2612
            return Ok(NodeMap {
2✔
2613
                node_map: DictMap::new(),
2✔
2614
            });
2✔
2615
        }
66✔
2616
        // Copy edges from other to self
2617
        for edge in other.graph.edge_references().filter(|edge| {
362✔
2618
            out_map.contains_key(&edge.target().index())
362✔
2619
                && out_map.contains_key(&edge.source().index())
356✔
2620
        }) {
362✔
2621
            self._add_edge(
356✔
2622
                NodeIndex::new(out_map[&edge.source().index()]),
356✔
2623
                NodeIndex::new(out_map[&edge.target().index()]),
356✔
2624
                weight_map_fn(edge.weight(), &edge_weight_map)?,
356✔
2625
            )?;
×
2626
        }
2627
        // Add edges to/from node to nodes in other
2628
        let in_edges: Vec<(NodeIndex, NodeIndex, PyObject)> = self
66✔
2629
            .graph
66✔
2630
            .edges_directed(node_index, petgraph::Direction::Incoming)
66✔
2631
            .map(|edge| (edge.source(), edge.target(), edge.weight().clone_ref(py)))
66✔
2632
            .collect();
66✔
2633
        let out_edges: Vec<(NodeIndex, NodeIndex, PyObject)> = self
66✔
2634
            .graph
66✔
2635
            .edges_directed(node_index, petgraph::Direction::Outgoing)
66✔
2636
            .map(|edge| (edge.source(), edge.target(), edge.weight().clone_ref(py)))
66✔
2637
            .collect();
66✔
2638
        for (source, target, weight) in in_edges {
78✔
2639
            let old_index = map_fn(source.index(), target.index(), &weight)?;
14✔
2640
            let target_out = match old_index {
14✔
2641
                Some(old_index) => match out_map.get(&old_index) {
12✔
2642
                    Some(new_index) => NodeIndex::new(*new_index),
10✔
2643
                    None => {
2644
                        return Err(PyIndexError::new_err(format!(
2✔
2645
                            "No mapped index {} found",
2✔
2646
                            old_index
2✔
2647
                        )))
2✔
2648
                    }
2649
                },
2650
                None => continue,
2✔
2651
            };
2652
            self._add_edge(source, target_out, weight)?;
10✔
2653
        }
2654
        for (source, target, weight) in out_edges {
92✔
2655
            let old_index = map_fn(source.index(), target.index(), &weight)?;
32✔
2656
            let source_out = match old_index {
32✔
2657
                Some(old_index) => match out_map.get(&old_index) {
30✔
2658
                    Some(new_index) => NodeIndex::new(*new_index),
26✔
2659
                    None => {
2660
                        return Err(PyIndexError::new_err(format!(
4✔
2661
                            "No mapped index {} found",
4✔
2662
                            old_index
4✔
2663
                        )))
4✔
2664
                    }
2665
                },
2666
                None => continue,
2✔
2667
            };
2668
            self._add_edge(source_out, target, weight)?;
26✔
2669
        }
2670
        // Remove node
2671
        self.remove_node(node_index.index())?;
60✔
2672
        Ok(NodeMap { node_map: out_map })
60✔
2673
    }
70✔
2674

2675
    /// Substitute a set of nodes with a single new node.
2676
    ///
2677
    /// :param list nodes: A set of nodes to be removed and replaced
2678
    ///     by the new node. Any nodes not in the graph are ignored.
2679
    ///     If empty, this method behaves like :meth:`~PyDiGraph.add_node`
2680
    ///     (but slower).
2681
    /// :param object obj: The data/weight to associate with the new node.
2682
    /// :param bool check_cycle: If set to ``True``, validates
2683
    ///     that the contraction will not introduce cycles before
2684
    ///     modifying the graph. If set to ``False``, validation is
2685
    ///     skipped. If not provided, inherits the value
2686
    ///     of ``check_cycle`` from this instance of
2687
    ///     :class:`~rustworkx.PyDiGraph`.
2688
    /// :param weight_combo_fn: An optional python callable that, when
2689
    ///     specified, is used to merge parallel edges introduced by the
2690
    ///     contraction, which will occur when multiple nodes in
2691
    ///     ``nodes`` have an incoming edge
2692
    ///     from the same source node or when multiple nodes in
2693
    ///     ``nodes`` have an outgoing edge to the same target node.
2694
    ///     If this instance of :class:`~rustworkx.PyDiGraph` is a multigraph,
2695
    ///     leave this unspecified to preserve parallel edges. If unspecified
2696
    ///     when not a multigraph, parallel edges and their weights will be
2697
    ///     combined by choosing one of the edge's weights arbitrarily based
2698
    ///     on an internal iteration order, subject to change.
2699
    /// :returns: The index of the newly created node.
2700
    /// :raises DAGWouldCycle: The cycle check is enabled and the
2701
    ///     contraction would introduce cycle(s).
2702
    #[pyo3(text_signature = "(self, nodes, obj, /, check_cycle=None, weight_combo_fn=None)", signature = (nodes, obj, check_cycle=None, weight_combo_fn=None))]
2703
    pub fn contract_nodes(
32✔
2704
        &mut self,
32✔
2705
        py: Python,
32✔
2706
        nodes: Vec<usize>,
32✔
2707
        obj: PyObject,
32✔
2708
        check_cycle: Option<bool>,
32✔
2709
        weight_combo_fn: Option<PyObject>,
32✔
2710
    ) -> RxPyResult<usize> {
32✔
2711
        let nodes = nodes.into_iter().map(|i| NodeIndex::new(i));
74✔
2712
        let check_cycle = check_cycle.unwrap_or(self.check_cycle);
32✔
2713
        let res = match (weight_combo_fn, &self.multigraph) {
32✔
2714
            (Some(user_callback), _) => {
2✔
2715
                self.graph
2✔
2716
                    .contract_nodes_simple(nodes, obj, check_cycle, |w1, w2| {
8✔
2717
                        user_callback.call1(py, (w1, w2))
8✔
2718
                    })?
8✔
2719
            }
2720
            (None, false) => {
2721
                // By default, just take first edge.
2722
                self.graph
4✔
2723
                    .contract_nodes_simple(nodes, obj, check_cycle, move |w1, _| {
8✔
2724
                        Ok::<_, PyErr>(w1.clone_ref(py))
8✔
2725
                    })?
8✔
2726
            }
2727
            (None, true) => self.graph.contract_nodes(nodes, obj, check_cycle)?,
26✔
2728
        };
2729
        Ok(res.index())
22✔
2730
    }
32✔
2731

2732
    /// Return a new PyDiGraph object for a subgraph of this graph
2733
    ///
2734
    /// :param list nodes: A list of node indices to generate the subgraph
2735
    ///     from. If a node index is included that is not present in the graph
2736
    ///     it will silently be ignored.
2737
    /// :param preserve_attrs: If set to the True the attributes of the PyDiGraph
2738
    ///     will be copied by reference to be the attributes of the output
2739
    ///     subgraph. By default this is set to False and the :attr:`~.PyDiGraph.attrs`
2740
    ///     attribute will be ``None`` in the subgraph.
2741
    ///
2742
    /// :returns: A new PyDiGraph object representing a subgraph of this graph.
2743
    ///     It is worth noting that node and edge weight/data payloads are
2744
    ///     passed by reference so if you update (not replace) an object used
2745
    ///     as the weight in graph or the subgraph it will also be updated in
2746
    ///     the other.
2747
    /// :rtype: PyGraph
2748
    ///
2749
    #[pyo3(signature=(nodes, preserve_attrs=false),text_signature = "(self, nodes, /, preserve_attrs=False)")]
2750
    pub fn subgraph(&self, py: Python, nodes: Vec<usize>, preserve_attrs: bool) -> PyDiGraph {
10✔
2751
        let node_set: HashSet<usize> = nodes.iter().cloned().collect();
10✔
2752
        let mut node_map: HashMap<NodeIndex, NodeIndex> = HashMap::with_capacity(nodes.len());
10✔
2753
        let node_filter = |node: NodeIndex| -> bool { node_set.contains(&node.index()) };
98✔
2754
        let mut out_graph = StablePyGraph::<Directed>::new();
10✔
2755
        let filtered = NodeFiltered(&self.graph, node_filter);
10✔
2756
        for node in filtered.node_references() {
16✔
2757
            let new_node = out_graph.add_node(node.1.clone_ref(py));
16✔
2758
            node_map.insert(node.0, new_node);
16✔
2759
        }
16✔
2760
        for edge in filtered.edge_references() {
14✔
2761
            let new_source = *node_map.get(&edge.source()).unwrap();
14✔
2762
            let new_target = *node_map.get(&edge.target()).unwrap();
14✔
2763
            out_graph.add_edge(new_source, new_target, edge.weight().clone_ref(py));
14✔
2764
        }
14✔
2765
        let attrs = if preserve_attrs {
10✔
2766
            self.attrs.clone_ref(py)
×
2767
        } else {
2768
            py.None()
10✔
2769
        };
2770
        PyDiGraph {
10✔
2771
            graph: out_graph,
10✔
2772
            node_removed: false,
10✔
2773
            cycle_state: algo::DfsSpace::default(),
10✔
2774
            check_cycle: self.check_cycle,
10✔
2775
            multigraph: self.multigraph,
10✔
2776
            attrs,
10✔
2777
        }
10✔
2778
    }
10✔
2779

2780
    /// Return a new PyDiGraph object for an edge induced subgraph of this graph
2781
    ///
2782
    /// The induced subgraph contains each edge in `edge_list` and each node
2783
    /// incident to any of those edges.
2784
    ///
2785
    /// :param list edge_list: A list of edge tuples (2-tuples with the source and
2786
    ///     target node) to generate the subgraph from. In cases of parallel
2787
    ///     edges for a multigraph all edges between the specified node. In case
2788
    ///     of an edge specified that doesn't exist in the graph it will be
2789
    ///     silently ignored.
2790
    ///
2791
    /// :returns: The edge subgraph
2792
    /// :rtype: PyDiGraph
2793
    ///
2794
    #[pyo3(text_signature = "(self, edge_list, /)")]
2795
    pub fn edge_subgraph(&self, edge_list: Vec<[usize; 2]>) -> PyDiGraph {
8✔
2796
        // Filter non-existent edges
8✔
2797
        let edges: Vec<[usize; 2]> = edge_list
8✔
2798
            .into_iter()
8✔
2799
            .filter(|x| {
14✔
2800
                let source = NodeIndex::new(x[0]);
14✔
2801
                let target = NodeIndex::new(x[1]);
14✔
2802
                self.graph.find_edge(source, target).is_some()
14✔
2803
            })
14✔
2804
            .collect();
8✔
2805

8✔
2806
        let nodes: HashSet<NodeIndex> = edges
8✔
2807
            .iter()
8✔
2808
            .flat_map(|x| x.iter())
12✔
2809
            .copied()
8✔
2810
            .map(NodeIndex::new)
8✔
2811
            .collect();
8✔
2812
        let mut edge_set: HashSet<[NodeIndex; 2]> = HashSet::with_capacity(edges.len());
8✔
2813
        for edge in edges {
20✔
2814
            let source_index = NodeIndex::new(edge[0]);
12✔
2815
            let target_index = NodeIndex::new(edge[1]);
12✔
2816
            edge_set.insert([source_index, target_index]);
12✔
2817
        }
12✔
2818
        let mut out_graph = self.clone();
8✔
2819
        for node in self
14✔
2820
            .graph
8✔
2821
            .node_indices()
8✔
2822
            .filter(|node| !nodes.contains(node))
32✔
2823
        {
14✔
2824
            out_graph.graph.remove_node(node);
14✔
2825
            out_graph.node_removed = true;
14✔
2826
        }
14✔
2827
        for edge in self
28✔
2828
            .graph
8✔
2829
            .edge_references()
8✔
2830
            .filter(|edge| !edge_set.contains(&[edge.source(), edge.target()]))
44✔
2831
        {
28✔
2832
            out_graph.graph.remove_edge(edge.id());
28✔
2833
        }
28✔
2834
        out_graph
8✔
2835
    }
8✔
2836

2837
    /// Check if the graph is symmetric
2838
    ///
2839
    /// :returns: True if the graph is symmetric
2840
    /// :rtype: bool
2841
    #[pyo3(text_signature = "(self)")]
2842
    pub fn is_symmetric(&self) -> bool {
4✔
2843
        let mut edges: HashSet<(NodeIndex, NodeIndex)> = HashSet::new();
4✔
2844
        for (source, target) in self
20✔
2845
            .graph
4✔
2846
            .edge_references()
4✔
2847
            .map(|edge| (edge.source(), edge.target()))
20✔
2848
        {
2849
            let edge = (source, target);
20✔
2850
            let reversed = (target, source);
20✔
2851
            if edges.contains(&reversed) {
20✔
2852
                edges.remove(&reversed);
8✔
2853
            } else {
12✔
2854
                edges.insert(edge);
12✔
2855
            }
12✔
2856
        }
2857
        edges.is_empty()
4✔
2858
    }
4✔
2859

2860
    /// Make edges in graph symmetric
2861
    ///
2862
    /// This function iterates over all the edges in the graph, adding for each
2863
    /// edge the reversed edge, unless one is already present. Note the edge insertion
2864
    /// is not fixed and the edge indices are not guaranteed to be consistent
2865
    /// between executions of this method on identical graphs.
2866
    ///
2867
    /// :param callable edge_payload: This optional argument takes in a callable which will
2868
    ///     be passed a single positional argument the data payload for an edge that will
2869
    ///     have a reverse copied in the graph. The returned value from this callable will
2870
    ///     be used as the data payload for the new edge created. If this is not specified
2871
    ///     then by default the data payload will be copied when the reverse edge is added.
2872
    ///     If there are parallel edges, then one of the edges (typically the one with the lower
2873
    ///     index, but this is not a guarantee) will be copied.
2874
    #[pyo3(signature = (edge_payload_fn=None))]
2875
    pub fn make_symmetric(
14✔
2876
        &mut self,
14✔
2877
        py: Python,
14✔
2878
        edge_payload_fn: Option<PyObject>,
14✔
2879
    ) -> PyResult<()> {
14✔
2880
        let edges: HashMap<[NodeIndex; 2], EdgeIndex> = self
14✔
2881
            .graph
14✔
2882
            .edge_references()
14✔
2883
            .map(|edge| ([edge.source(), edge.target()], edge.id()))
38✔
2884
            .collect();
14✔
2885
        for ([edge_source, edge_target], edge_index) in edges.iter() {
34✔
2886
            if !edges.contains_key(&[*edge_target, *edge_source]) {
34✔
2887
                let forward_weight = self.graph.edge_weight(*edge_index).unwrap();
18✔
2888
                let weight: PyObject = match edge_payload_fn.as_ref() {
18✔
2889
                    Some(callback) => callback.call1(py, (forward_weight,))?,
10✔
2890
                    None => forward_weight.clone_ref(py),
8✔
2891
                };
2892
                self._add_edge(*edge_target, *edge_source, weight)?;
16✔
2893
            }
16✔
2894
        }
2895
        Ok(())
12✔
2896
    }
14✔
2897

2898
    /// Generate a new PyGraph object from this graph
2899
    ///
2900
    /// This will create a new :class:`~rustworkx.PyGraph` object from this
2901
    /// graph. All edges in this graph will be created as undirected edges in
2902
    /// the new graph object. For directed graphs with bidirectional edges, you
2903
    /// can set `multigraph=False` to condense them into a single edge and specify
2904
    /// a function to combine the weights/data of the edges.
2905
    /// Do note that the node and edge weights/data payloads will be passed
2906
    /// by reference to the new :class:`~rustworkx.PyGraph` object.
2907
    ///
2908
    /// .. note::
2909
    ///
2910
    ///     The node indices in the output :class:`~rustworkx.PyGraph` may
2911
    ///     differ if nodes have been removed.
2912
    ///
2913
    /// :param bool multigraph: If set to `False` the output graph will not
2914
    ///     allow parallel edges. Instead parallel edges will be condensed
2915
    ///     into a single edge and their data will be combined using
2916
    ///     `weight_combo_fn`. If `weight_combo_fn` is not provided, the data
2917
    ///     of the edge with the largest index will be kept. Default: `True`.
2918
    /// :param weight_combo_fn: An optional python callable that will take in a
2919
    ///     two edge weight/data object and return a new edge weight/data
2920
    ///     object that will be used when adding an edge between two nodes
2921
    ///     connected by multiple edges (of either direction) in the original
2922
    ///     directed graph.
2923
    /// :returns: A new PyGraph object with an undirected edge for every
2924
    ///     directed edge in this graph
2925
    /// :rtype: PyGraph
2926
    #[pyo3(signature=(multigraph=true, weight_combo_fn=None), text_signature = "(self, /, multigraph=True, weight_combo_fn=None)")]
2927
    pub fn to_undirected(
28✔
2928
        &self,
28✔
2929
        py: Python,
28✔
2930
        multigraph: bool,
28✔
2931
        weight_combo_fn: Option<PyObject>,
28✔
2932
    ) -> PyResult<crate::graph::PyGraph> {
28✔
2933
        let node_count = self.node_count();
28✔
2934
        let mut new_graph = if multigraph {
28✔
2935
            StablePyGraph::<Undirected>::with_capacity(node_count, self.graph.edge_count())
24✔
2936
        } else {
2937
            // If multigraph is false edge count is difficult to predict
2938
            // without counting parallel edges. So, just stick with 0 and
2939
            // reallocate dynamically
2940
            StablePyGraph::<Undirected>::with_capacity(node_count, 0)
4✔
2941
        };
2942

2943
        let mut node_map: HashMap<NodeIndex, NodeIndex> = HashMap::with_capacity(node_count);
28✔
2944

28✔
2945
        let combine = |a: &PyObject,
28✔
2946
                       b: &PyObject,
2947
                       combo_fn: &Option<PyObject>|
2948
         -> PyResult<Option<PyObject>> {
10✔
2949
            match combo_fn {
10✔
2950
                Some(combo_fn) => {
2✔
2951
                    let res = combo_fn.call1(py, (a, b))?;
2✔
2952
                    Ok(Some(res))
2✔
2953
                }
2954
                None => Ok(None),
8✔
2955
            }
2956
        };
10✔
2957

2958
        for node_index in self.graph.node_indices() {
134✔
2959
            let node = self.graph[node_index].clone_ref(py);
134✔
2960
            let new_index = new_graph.add_node(node);
134✔
2961
            node_map.insert(node_index, new_index);
134✔
2962
        }
134✔
2963
        for edge in self.graph.edge_references() {
176✔
2964
            let &source = node_map.get(&edge.source()).unwrap();
176✔
2965
            let &target = node_map.get(&edge.target()).unwrap();
176✔
2966
            let weight = edge.weight().clone_ref(py);
176✔
2967
            if multigraph {
176✔
2968
                new_graph.add_edge(source, target, weight);
156✔
2969
            } else {
156✔
2970
                let exists = new_graph.find_edge(source, target);
20✔
2971
                match exists {
20✔
2972
                    Some(index) => {
10✔
2973
                        let old_weight = new_graph.edge_weight_mut(index).unwrap();
10✔
2974
                        match combine(old_weight, edge.weight(), &weight_combo_fn)? {
10✔
2975
                            Some(value) => {
2✔
2976
                                *old_weight = value;
2✔
2977
                            }
2✔
2978
                            None => {
8✔
2979
                                *old_weight = weight;
8✔
2980
                            }
8✔
2981
                        }
2982
                    }
2983
                    None => {
10✔
2984
                        new_graph.add_edge(source, target, weight);
10✔
2985
                    }
10✔
2986
                }
2987
            }
2988
        }
2989
        Ok(crate::graph::PyGraph {
28✔
2990
            graph: new_graph,
28✔
2991
            node_removed: false,
28✔
2992
            multigraph,
28✔
2993
            attrs: py.None(),
28✔
2994
        })
28✔
2995
    }
28✔
2996

2997
    /// Return a shallow copy of the graph
2998
    ///
2999
    /// All node and edge weight/data payloads in the copy will have a
3000
    /// shared reference to the original graph.
3001
    #[pyo3(text_signature = "(self)")]
3002
    pub fn copy(&self) -> PyDiGraph {
10✔
3003
        self.clone()
10✔
3004
    }
10✔
3005

3006
    /// Reverse the direction of all edges in the graph, in place.
3007
    ///
3008
    /// This method modifies the graph instance to reverse the direction of all edges.
3009
    /// It does so by iterating over all edges in the graph and removing each edge,
3010
    /// then adding a new edge in the opposite direction with the same weight.
3011
    ///
3012
    /// For Example::
3013
    ///
3014
    ///     import rustworkx as rx
3015
    ///
3016
    ///     graph = rx.PyDiGraph()
3017
    ///
3018
    ///     # Generate a path directed path graph with weights
3019
    ///     graph.extend_from_weighted_edge_list([
3020
    ///         (0, 1, 3),
3021
    ///         (1, 2, 5),
3022
    ///         (2, 3, 2),
3023
    ///     ])
3024
    ///     # Reverse edges
3025
    ///     graph.reverse()
3026
    ///
3027
    ///     assert graph.weighted_edge_list() == [(3, 2, 2), (2, 1, 5), (1, 0, 3)];
3028
    #[pyo3(text_signature = "(self)")]
3029
    pub fn reverse(&mut self, py: Python) {
18✔
3030
        let indices = self.graph.edge_indices().collect::<Vec<EdgeIndex>>();
18✔
3031
        for idx in indices {
2,000,106✔
3032
            let (source_node, dest_node) = self.graph.edge_endpoints(idx).unwrap();
2,000,088✔
3033
            let weight = self.graph.edge_weight(idx).unwrap().clone_ref(py);
2,000,088✔
3034
            self.graph.remove_edge(idx);
2,000,088✔
3035
            self.graph.add_edge(dest_node, source_node, weight);
2,000,088✔
3036
        }
2,000,088✔
3037
    }
18✔
3038

3039
    /// Filters a graph's nodes by some criteria conditioned on a node's data payload and returns those nodes' indices.
3040
    ///
3041
    /// This function takes in a function as an argument. This filter function will be passed in a node's data payload and is
3042
    /// required to return a boolean value stating whether the node's data payload fits some criteria.
3043
    ///
3044
    /// For example::
3045
    ///
3046
    ///     from rustworkx import PyDiGraph
3047
    ///
3048
    ///     graph = PyDiGraph()
3049
    ///     graph.add_nodes_from(list(range(5)))
3050
    ///
3051
    ///     def my_filter_function(node):
3052
    ///         return node > 2
3053
    ///
3054
    ///     indices = graph.filter_nodes(my_filter_function)
3055
    ///     assert indices == [3, 4]
3056
    ///
3057
    /// :param filter_function: Function with which to filter nodes
3058
    /// :returns: The node indices that match the filter
3059
    /// :rtype: NodeIndices
3060
    #[pyo3(text_signature = "(self, filter_function)")]
3061
    pub fn filter_nodes(&self, py: Python, filter_function: PyObject) -> PyResult<NodeIndices> {
8✔
3062
        let filter = |nindex: NodeIndex| -> PyResult<bool> {
32✔
3063
            let res = filter_function.call1(py, (&self.graph[nindex],))?;
32✔
3064
            res.extract(py)
30✔
3065
        };
32✔
3066

3067
        let mut n = Vec::with_capacity(self.graph.node_count());
8✔
3068
        for node_index in self.graph.node_indices() {
32✔
3069
            if filter(node_index)? {
32✔
3070
                n.push(node_index.index())
8✔
3071
            };
22✔
3072
        }
3073
        Ok(NodeIndices { nodes: n })
6✔
3074
    }
8✔
3075

3076
    /// Filters a graph's edges by some criteria conditioned on a edge's data payload and returns those edges' indices.
3077
    ///
3078
    /// This function takes in a function as an argument. This filter function will be passed in an edge's data payload and is
3079
    /// required to return a boolean value stating whether the edge's data payload fits some criteria.
3080
    ///
3081
    /// For example::
3082
    ///
3083
    ///     from rustworkx import PyGraph
3084
    ///     from rustworkx.generators import complete_graph
3085
    ///
3086
    ///     graph = PyGraph()
3087
    ///     graph.add_nodes_from(range(3))
3088
    ///     graph.add_edges_from([(0, 1, 'A'), (0, 1, 'B'), (1, 2, 'C')])
3089
    ///
3090
    ///     def my_filter_function(edge):
3091
    ///         if edge:
3092
    ///             return edge == 'B'
3093
    ///         return False
3094
    ///
3095
    ///     indices = graph.filter_edges(my_filter_function)
3096
    ///     assert indices == [1]
3097
    ///
3098
    /// :param filter_function: Function with which to filter edges
3099
    /// :returns: The edge indices that match the filter
3100
    /// :rtype: EdgeIndices
3101
    #[pyo3(text_signature = "(self, filter_function)")]
3102
    pub fn filter_edges(&self, py: Python, filter_function: PyObject) -> PyResult<EdgeIndices> {
8✔
3103
        let filter = |eindex: EdgeIndex| -> PyResult<bool> {
20✔
3104
            let res = filter_function.call1(py, (&self.graph[eindex],))?;
20✔
3105
            res.extract(py)
18✔
3106
        };
20✔
3107

3108
        let mut e = Vec::with_capacity(self.graph.edge_count());
8✔
3109
        for edge_index in self.graph.edge_indices() {
20✔
3110
            if filter(edge_index)? {
20✔
3111
                e.push(edge_index.index())
6✔
3112
            };
12✔
3113
        }
3114
        Ok(EdgeIndices { edges: e })
6✔
3115
    }
8✔
3116

3117
    /// Return the number of nodes in the graph
3118
    fn __len__(&self) -> PyResult<usize> {
208✔
3119
        Ok(self.graph.node_count())
208✔
3120
    }
208✔
3121

3122
    fn __getitem__(&self, idx: usize) -> PyResult<&PyObject> {
52✔
3123
        match self.graph.node_weight(NodeIndex::new(idx)) {
52✔
3124
            Some(data) => Ok(data),
50✔
3125
            None => Err(PyIndexError::new_err("No node found for index")),
2✔
3126
        }
3127
    }
52✔
3128

3129
    fn __setitem__(&mut self, idx: usize, value: PyObject) -> PyResult<()> {
3,770✔
3130
        let data = match self.graph.node_weight_mut(NodeIndex::new(idx)) {
3,770✔
3131
            Some(node_data) => node_data,
3,768✔
3132
            None => return Err(PyIndexError::new_err("No node found for index")),
2✔
3133
        };
3134
        *data = value;
3,768✔
3135
        Ok(())
3,768✔
3136
    }
3,770✔
3137

3138
    fn __delitem__(&mut self, idx: usize) -> PyResult<()> {
4✔
3139
        match self.graph.remove_node(NodeIndex::new(idx)) {
4✔
3140
            Some(_) => {
3141
                self.node_removed = true;
2✔
3142
                Ok(())
2✔
3143
            }
3144
            None => Err(PyIndexError::new_err("No node found for index")),
2✔
3145
        }
3146
    }
4✔
3147

3148
    #[classmethod]
3149
    #[pyo3(signature = (key, /))]
3150
    pub fn __class_getitem__(
4✔
3151
        cls: &Bound<'_, PyType>,
4✔
3152
        key: &Bound<'_, PyAny>,
4✔
3153
    ) -> PyResult<PyObject> {
4✔
3154
        generic_class_getitem(cls, key)
4✔
3155
    }
4✔
3156

3157
    // Functions to enable Python Garbage Collection
3158

3159
    // Function for PyTypeObject.tp_traverse [1][2] used to tell Python what
3160
    // objects the PyDiGraph has strong references to.
3161
    //
3162
    // [1] https://docs.python.org/3/c-api/typeobj.html#c.PyTypeObject.tp_traverse
3163
    // [2] https://pyo3.rs/v0.12.4/class/protocols.html#garbage-collector-integration
3164
    fn __traverse__(&self, visit: PyVisit) -> Result<(), PyTraverseError> {
128✔
3165
        for node in self
14,682,996✔
3166
            .graph
128✔
3167
            .node_indices()
128✔
3168
            .map(|node| self.graph.node_weight(node).unwrap())
14,682,996✔
3169
        {
3170
            visit.call(node)?;
14,682,996✔
3171
        }
3172
        for edge in self
1,356,912✔
3173
            .graph
128✔
3174
            .edge_indices()
128✔
3175
            .map(|edge| self.graph.edge_weight(edge).unwrap())
1,356,912✔
3176
        {
3177
            visit.call(edge)?;
1,356,912✔
3178
        }
3179
        visit.call(&self.attrs)?;
128✔
3180
        Ok(())
128✔
3181
    }
128✔
3182

3183
    // Function for PyTypeObject.tp_clear [1][2] used to tell Python's GC how
3184
    // to drop all references held by a PyDiGraph object when the GC needs to
3185
    // break reference cycles.
3186
    //
3187
    // ]1] https://docs.python.org/3/c-api/typeobj.html#c.PyTypeObject.tp_clear
3188
    // [2] https://pyo3.rs/v0.12.4/class/protocols.html#garbage-collector-integration
3189
    fn __clear__(&mut self, py: Python) {
×
3190
        self.graph = StablePyGraph::<Directed>::new();
×
3191
        self.node_removed = false;
×
3192
        self.attrs = py.None();
×
3193
    }
×
3194
}
3195

3196
fn is_cycle_check_required(dag: &PyDiGraph, a: NodeIndex, b: NodeIndex) -> bool {
26✔
3197
    let mut parents_a = dag
26✔
3198
        .graph
26✔
3199
        .neighbors_directed(a, petgraph::Direction::Incoming);
26✔
3200
    let mut children_b = dag
26✔
3201
        .graph
26✔
3202
        .neighbors_directed(b, petgraph::Direction::Outgoing);
26✔
3203
    parents_a.next().is_some() && children_b.next().is_some() && dag.graph.find_edge(a, b).is_none()
26✔
3204
}
26✔
3205

3206
fn weight_transform_callable(
58✔
3207
    py: Python,
58✔
3208
    map_fn: &Option<PyObject>,
58✔
3209
    value: &PyObject,
58✔
3210
) -> PyResult<PyObject> {
58✔
3211
    match map_fn {
58✔
3212
        Some(map_fn) => {
10✔
3213
            let res = map_fn.call1(py, (value,))?;
10✔
3214
            res.into_py_any(py)
10✔
3215
        }
3216
        None => Ok(value.clone_ref(py)),
48✔
3217
    }
3218
}
58✔
3219

3220
fn _from_adjacency_matrix<'p, T>(
26✔
3221
    py: Python<'p>,
26✔
3222
    matrix: PyReadonlyArray2<'p, T>,
26✔
3223
    null_value: T,
26✔
3224
) -> PyResult<PyDiGraph>
26✔
3225
where
26✔
3226
    T: Copy + std::cmp::PartialEq + numpy::Element + pyo3::IntoPyObject<'p> + IsNan,
26✔
3227
{
26✔
3228
    let array = matrix.as_array();
26✔
3229
    let shape = array.shape();
26✔
3230
    let mut out_graph = StablePyGraph::<Directed>::new();
26✔
3231
    let _node_indices: Vec<NodeIndex> = (0..shape[0])
26✔
3232
        .map(|node| Ok(out_graph.add_node(node.into_py_any(py)?)))
272✔
3233
        .collect::<PyResult<Vec<NodeIndex>>>()?;
26✔
3234
    for (index, row) in array.axis_iter(Axis(0)).enumerate() {
272✔
3235
        let source_index = NodeIndex::new(index);
272✔
3236
        for (target_index, elem) in row.iter().enumerate() {
20,216✔
3237
            if null_value.is_nan() {
20,216✔
3238
                if !elem.is_nan() {
36✔
3239
                    out_graph.add_edge(
16✔
3240
                        source_index,
16✔
3241
                        NodeIndex::new(target_index),
16✔
3242
                        elem.into_py_any(py)?,
16✔
3243
                    );
3244
                }
20✔
3245
            } else if *elem != null_value {
20,180✔
3246
                out_graph.add_edge(
18,882✔
3247
                    source_index,
18,882✔
3248
                    NodeIndex::new(target_index),
18,882✔
3249
                    elem.into_py_any(py)?,
18,882✔
3250
                );
3251
            }
1,298✔
3252
        }
3253
    }
3254
    Ok(PyDiGraph {
26✔
3255
        graph: out_graph,
26✔
3256
        cycle_state: algo::DfsSpace::default(),
26✔
3257
        check_cycle: false,
26✔
3258
        node_removed: false,
26✔
3259
        multigraph: true,
26✔
3260
        attrs: py.None(),
26✔
3261
    })
26✔
3262
}
26✔
3263

3264
/// Simple wrapper newtype that lets us use `Py` pointers as hash keys with the equality defined by
3265
/// the pointer address.  This is equivalent to using Python's `is` operator for comparisons.
3266
/// Using a newtype rather than casting the pointer to `usize` inline lets us retrieve a copy of
3267
/// the reference from the key entry.
3268
struct PyAnyId(Py<PyAny>);
3269
impl PyAnyId {
3270
    fn clone_ref(&self, py: Python) -> Py<PyAny> {
30✔
3271
        self.0.clone_ref(py)
30✔
3272
    }
30✔
3273
}
3274
impl ::std::hash::Hash for PyAnyId {
3275
    fn hash<H: ::std::hash::Hasher>(&self, state: &mut H) {
86✔
3276
        (self.0.as_ptr() as usize).hash(state)
86✔
3277
    }
86✔
3278
}
3279
impl PartialEq for PyAnyId {
3280
    fn eq(&self, other: &Self) -> bool {
28✔
3281
        self.0.as_ptr() == other.0.as_ptr()
28✔
3282
    }
28✔
3283
}
3284
impl Eq for PyAnyId {}
3285

3286
/// Internal-only helper class used by `remove_node_retain_edges_by_key` to store its data as a
3287
/// typed object in a Python dictionary.
3288
#[pyclass]
24✔
3289
struct RemoveNodeEdgeValue {
3290
    weight: Py<PyAny>,
3291
    nodes: Vec<NodeIndex>,
3292
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc