• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

rmcar17 / cogent3 / 19318630917

10 Nov 2025 08:08PM UTC coverage: 90.631% (+0.008%) from 90.623%
19318630917

push

github

web-flow
Merge pull request #2518 from cogent3/dependabot/pip/ruff-0.14.4

Bump ruff from 0.14.3 to 0.14.4

28277 of 31200 relevant lines covered (90.63%)

5.44 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

94.66
/src/cogent3/core/tree.py
1
"""Classes for storing and manipulating a phylogenetic tree.
2

3
These trees can be either strictly binary, or have polytomies
4
(multiple children to a parent node).
5

6
Trees consist of Nodes (or branches) that connect two nodes. The Tree can
7
be created only from a newick formatted string read either from file or from a
8
string object. Other formats will be added as time permits.
9

10
Tree can:
11
    -  Deal with either rooted or unrooted trees and can
12
       convert between these types.
13
    -  Return a sub-tree given a list of tip-names
14
    -  Identify an edge given two tip names. This method facilitates the
15
       statistical modelling by simplifying the syntax for specifying
16
       sub-regions of a tree.
17
    -  Assess whether two Tree instances represent the same topology.
18

19
Definition of relevant terms or abbreviations:
20
    -  edge: also known as a branch on a tree.
21
    -  node: the point at which two edges meet
22
    -  tip: a sequence or species
23
    -  clade: all and only the nodes (including tips) that descend
24
       from a node
25
    -  stem: the edge immediately preceding a clade
26
"""
27

28
from __future__ import annotations
6✔
29

30
import contextlib
6✔
31
import json
6✔
32
import random
6✔
33
import re
6✔
34
import warnings
6✔
35
from copy import deepcopy
6✔
36
from functools import reduce
6✔
37
from itertools import combinations
6✔
38
from operator import or_
6✔
39
from typing import (
6✔
40
    TYPE_CHECKING,
41
    Any,
42
    Literal,
43
    SupportsIndex,
44
    TypeVar,
45
    cast,
46
    overload,
47
)
48

49
import numpy
6✔
50
import numpy.typing as npt
6✔
51

52
from cogent3._version import __version__
6✔
53
from cogent3.parse.cogent3_json import load_from_json
6✔
54
from cogent3.parse.newick import parse_string as newick_parse_string
6✔
55
from cogent3.phylo.tree_distance import get_tree_distance_measure
6✔
56
from cogent3.util.deserialise import register_deserialiser
6✔
57
from cogent3.util.io import atomic_write, get_format_suffixes, open_
6✔
58
from cogent3.util.misc import get_object_provenance
6✔
59

60
if TYPE_CHECKING:  # pragma: no cover
61
    import os
62
    import pathlib
63
    from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
64
    from typing import Self
65

66
    from cogent3.draw.dendrogram import Dendrogram
67
    from cogent3.evolve.fast_distance import DistanceMatrix
68

69
    PySeq = Sequence
70
    PySeqStr = PySeq[str]
71

72

73
class TreeError(Exception):
    """Raised when a tree operation cannot be carried out on the given nodes."""
6✔
75

76

77
def _format_node_name(
6✔
78
    node: PhyloNode,
79
    with_node_names: bool,
80
    escape_name: bool,
81
    with_distances: bool,
82
    with_root_name: bool = False,
83
) -> str:
84
    """Helper function to format node name according to parameters"""
85
    if (node.is_root() and not with_root_name) or (
6✔
86
        not node.is_tip() and not with_node_names
87
    ):
88
        node_name = ""
6✔
89
    else:
90
        node_name = node.name or ""
6✔
91

92
    if (
6✔
93
        node_name
94
        and escape_name
95
        and not (node_name.startswith("'") and node_name.endswith("'"))
96
    ):
97
        if re.search("""[]['"(),:;_]""", node_name):
6✔
98
            node_name = "'{}'".format(node_name.replace("'", "''"))
6✔
99
        else:
100
            node_name = node_name.replace(" ", "_")
6✔
101

102
    if with_distances and (length := node.length) is not None:
6✔
103
        node_name = f"{node_name}:{length}"
6✔
104

105
    return node_name
6✔
106

107

108
class PhyloNode:
6✔
109
    """Store information about a tree node. Mutable.

    A node also stands for the (sub)tree rooted at it: iteration, indexing
    and ``len()`` operate on its children.

    Parameters:
        name: label for the node, assumed to be unique.
        children: the node's children, given as nodes or as names (a name
            creates a new node of the same class).
        parent: parent to this node; the new node is appended to the
            parent's children.
        params: dict containing arbitrary parameters for the node. A
            ``"length"`` entry is removed from it and used as ``length``.
        name_loaded: flag stored on the node as given (default True); not
            otherwise used by the methods in this section — TODO document.
        length: length of the edge leading to this node, or None.
        support: support value for the node, stored as given, or None.
    """

    # fixed attribute storage; instances have no per-instance __dict__
    __slots__ = (
        "_parent",
        "children",
        "length",
        "name",
        "name_loaded",
        "params",
        "support",
    )

    # attributes _copy_node skips: the parent/child links are rebuilt by copy()
    _exclude_from_copy = frozenset(["_parent", "children"])
6✔
130

131
    def __init__(
6✔
132
        self,
133
        name: str,
134
        children: Iterable[Self | str] | None = None,
135
        parent: Self | None = None,
136
        params: dict[str, Any] | None = None,
137
        name_loaded: bool = True,
138
        length: float | None = None,
139
        support: float | None = None,
140
    ) -> None:
141
        """Returns new PhyloNode object."""
142
        self.name = name
6✔
143
        self.name_loaded = name_loaded
6✔
144
        self.params = params or {}
6✔
145
        self.children: list[Self] = []
6✔
146
        if children:
6✔
147
            self.extend(children)
6✔
148

149
        self._parent = parent
6✔
150
        if parent is not None and self not in parent.children:
6✔
151
            parent.append(self)
6✔
152

153
        if "length" in self.params:
6✔
154
            if length is not None and self.params["length"] != length:
6✔
155
                msg = "Got two different lengths in params and as an argument."
×
156
                raise ValueError(msg)
×
157

158
            length = self.params.pop("length")
6✔
159

160
        self.length = length
6✔
161
        self.support = support
6✔
162

163
    # built-in methods and list interface support
164
    def __repr__(self) -> str:
        """Return ``Tree("<newick>")`` using the node's newick string.

        WARNING: the class name in the output is always ``Tree``, whatever
        the actual type of this object.
        """
        newick = self.get_newick()
        return f'Tree("{newick}")'
6✔
170

171
    def __str__(self) -> str:
        """Return the newick string of the tree, including branch lengths."""
        newick = self.get_newick(with_distances=True)
        return newick
6✔
174

175
    # TODO have methods that need to rely on identity of self and
176
    # other actually do that
177
    # For now, the following comparison operators are peculiar in that
178
    # by omitting eq/ne methods, those default to id(),
179
    # whereas sorting should be based on name
180
    # I think the remove etc .. operations should explicitly
181
    # use id()
182
    # def __eq__(self, other):
183
    # return self.name == other.name
184

185
    # def __ne__(self, other):
186
    # return self.name != other.name
187

188
    def __lt__(self, other: PhyloNode) -> bool:
6✔
189
        return self.name < other.name
6✔
190

191
    def __gt__(self, other: PhyloNode) -> bool:
6✔
192
        return self.name > other.name
6✔
193

194
    @property
6✔
195
    def source(self) -> str | None:
6✔
196
        return self.params.get("source")
6✔
197

198
    @source.setter
6✔
199
    def source(self, value: str | None) -> None:
6✔
200
        """Sets the source of the node."""
201
        if value:
6✔
202
            self.params["source"] = value
6✔
203
        else:
204
            self.params.pop("source", None)
6✔
205

206
    def compare_name(self, other: PhyloNode) -> bool:
6✔
207
        """Compares PhyloNode by name"""
208
        return self is other or self.name == other.name
6✔
209

210
    def compare_by_names(self, other: PhyloNode) -> bool:
6✔
211
        """Equality test for trees by name"""
212
        # if they are the same object then they must be the same tree...
213
        if self is other:
6✔
214
            return True
6✔
215
        self_names = self.get_node_names()
6✔
216
        other_names = other.get_node_names()
6✔
217
        self_names = [v for v in self_names if v]
6✔
218
        self_names.sort()
6✔
219
        other_names = [v for v in other_names if v]
6✔
220
        other_names.sort()
6✔
221
        return self_names == other_names
6✔
222

223
    def _to_self_child(self, i: Self | str) -> Self:
6✔
224
        """Converts i to self's type, with self as its parent.
225

226
        Cleans up refs from i's original parent, but doesn't give self ref to i.
227
        """
228
        if isinstance(i, str):
6✔
229
            node = self.__class__(i)
6✔
230
        else:
231
            if i._parent is not None and i._parent is not self:
6✔
232
                i._parent.children.remove(i)
6✔
233
            node = i
6✔
234
        node._parent = self
6✔
235
        return node
6✔
236

237
    def append(self, i: Self | str) -> None:
6✔
238
        """Appends i to self.children, in-place, cleaning up refs."""
239
        self.children.append(self._to_self_child(i))
6✔
240

241
    def extend(self, items: Iterable[Self | str]) -> None:
6✔
242
        """Extends self.children by items, in-place, cleaning up refs."""
243
        self.children.extend(self._to_self_child(item) for item in items)
6✔
244

245
    def insert(self, index: SupportsIndex, i: Self | str) -> None:
6✔
246
        """Inserts an item at specified position in self.children."""
247
        self.children.insert(index, self._to_self_child(i))
6✔
248

249
    def pop(self, index: SupportsIndex = -1) -> Self:
6✔
250
        """Returns and deletes child of self at index (default: -1)"""
251
        result = self.children.pop(index)
6✔
252
        result._parent = None
6✔
253
        return result
6✔
254

255
    def remove(self, target: Self | str) -> bool:
6✔
256
        """Removes node by name instead of identity.
257

258
        Returns True if node was present, False otherwise.
259
        """
260
        if isinstance(target, PhyloNode):
6✔
261
            target = target.name
6✔
262
        for _i, curr_node in enumerate(self.children):
6✔
263
            if curr_node.name == target:
6✔
264
                self.remove_node(curr_node)
6✔
265
                return True
6✔
266
        return False
×
267

268
    @overload
269
    def __getitem__(self, i: SupportsIndex) -> Self: ...
270
    @overload
271
    def __getitem__(self, i: slice) -> list[Self]: ...
272

273
    def __getitem__(self, i: SupportsIndex | slice) -> Self | list[Self]:
6✔
274
        return self.children[i]
6✔
275

276
    @overload
277
    def __setitem__(self, i: SupportsIndex, value: Self | str) -> None: ...
278
    @overload
279
    def __setitem__(self, i: slice, value: Iterable[Self | str]) -> None: ...
280

281
    def __setitem__(
6✔
282
        self, i: SupportsIndex | slice, value: Self | str | Iterable[Self | str]
283
    ) -> None:
284
        """Node[i] = x sets the corresponding item in children."""
285
        if isinstance(i, slice):
6✔
286
            nodes = cast("Iterable[Self]", value)
6✔
287

288
            current_children = self.children[i]
6✔
289

290
            for child in current_children:
6✔
291
                child._parent = None
6✔
292

293
            self.children[i] = [self._to_self_child(node) for node in nodes]
6✔
294
        else:
295
            node = cast("Self", value)
6✔
296

297
            current_child = self.children[i]
6✔
298

299
            current_child._parent = None
6✔
300
            self.children[i] = self._to_self_child(node)
6✔
301

302
    def __delitem__(self, i: SupportsIndex | slice) -> None:
6✔
303
        """del node[i] deletes index or slice from self.children."""
304
        if isinstance(i, slice):
6✔
305
            for child in self.children[i]:
6✔
306
                child._parent = None
6✔
307
        else:
308
            self.children[i]._parent = None
6✔
309
        del self.children[i]
6✔
310

311
    def __iter__(self) -> Iterator[Self]:
6✔
312
        """Node iter iterates over the children."""
313
        return iter(self.children)
6✔
314

315
    def __len__(self) -> int:
6✔
316
        """Node len returns number of children."""
317
        return len(self.children)
6✔
318

319
    @classmethod
6✔
320
    def _copy_node(cls, node: Self, memo: dict[int, Any] | None = None) -> Self:
6✔
321
        result = cls(node.name)
6✔
322
        efc = node._exclude_from_copy
6✔
323
        for k in node.__slots__:
6✔
324
            if k not in efc:
6✔
325
                setattr(result, k, deepcopy(getattr(node, k), memo=memo))
6✔
326
        return result
6✔
327

328
    def copy(self, memo: dict[int, Any] | None = None) -> Self:
6✔
329
        """Returns a copy of self using an iterative approach"""
330
        if memo is None:
6✔
331
            memo = {}
6✔
332

333
        obj_id = id(self)
6✔
334
        if obj_id in memo:
6✔
335
            return memo[obj_id]
×
336

337
        root = self.__class__._copy_node(self)
6✔
338
        nodes_stack = [(root, self, len(self.children))]
6✔
339

340
        while nodes_stack:
6✔
341
            # check the top node, any children left unvisited?
342
            top = nodes_stack[-1]
6✔
343
            new_top_node, old_top_node, unvisited_children = top
6✔
344

345
            if unvisited_children:
6✔
346
                nodes_stack[-1] = (new_top_node, old_top_node, unvisited_children - 1)
6✔
347
                old_child = old_top_node.children[-unvisited_children]
6✔
348
                new_child = self.__class__._copy_node(old_child)
6✔
349
                new_top_node.append(new_child)
6✔
350
                nodes_stack.append((new_child, old_child, len(old_child.children)))
6✔
351
            else:  # no unvisited children
352
                nodes_stack.pop()
6✔
353
        return root
6✔
354

355
    # ``node.deepcopy()`` and ``copy.deepcopy(node)`` both dispatch to ``copy``,
    # which copies the whole subtree iteratively.
    __deepcopy__ = deepcopy = copy
6✔
356

357
    # support for basic tree operations -- finding objects and moving in the
358
    # tree
359
    @property
6✔
360
    def parent(self) -> Self | None:
6✔
361
        """parent of this node"""
362
        return self._parent
6✔
363

364
    @parent.setter
6✔
365
    def parent(self, parent: Self | None) -> None:
6✔
366
        """parent of this node"""
367
        if self._parent is not None:
6✔
368
            self._parent.remove_node(self)
6✔
369
        self._parent = parent
6✔
370
        if parent is not None and self not in parent.children:
6✔
371
            parent.children.append(self)
6✔
372

373
    def index_in_parent(self: Self) -> int:
6✔
374
        """Returns index of self in parent."""
375
        if self._parent is None:
6✔
376
            msg = "Node has no parent."
×
377
            raise TreeError(msg)
×
378
        return self._parent.children.index(self)
6✔
379

380
    def is_tip(self) -> bool:
6✔
381
        """Returns True if the current node is a tip, i.e. has no children."""
382
        return not self.children
6✔
383

384
    def is_root(self) -> bool:
6✔
385
        """Returns True if the current is a root, i.e. has no parent."""
386
        return self._parent is None
6✔
387

388
    def levelorder(self, include_self: bool = True) -> Generator[Self]:
6✔
389
        """Performs levelorder iteration over tree"""
390
        queue = [self]
6✔
391
        while queue:
6✔
392
            curr = queue.pop(0)
6✔
393
            if include_self or (curr is not self):
6✔
394
                yield curr
6✔
395
            if curr.children:
6✔
396
                queue.extend(curr.children)
6✔
397

398
    def preorder(self, include_self: bool = True) -> Generator[Self]:
6✔
399
        """Performs preorder iteration over tree."""
400
        stack = [self]
6✔
401

402
        while stack:
6✔
403
            node = stack.pop()
6✔
404
            if include_self or node is not self:
6✔
405
                yield node
6✔
406

407
            # the stack is last-in-first-out, so we add children
408
            # in reverse order so they're processed left-to-right
409
            if node.children:
6✔
410
                stack.extend(node.children[::-1])
6✔
411

412
    def postorder(self, include_self: bool = True) -> Generator[Self]:
6✔
413
        """performs postorder iteration over tree"""
414
        stack = [(self, False)]
6✔
415

416
        while stack:
6✔
417
            node, children_done = stack.pop()
6✔
418
            if children_done:
6✔
419
                if include_self or node is not self:
6✔
420
                    yield node
6✔
421
            else:
422
                # children still need to be processed
423
                stack.append((node, True))
6✔
424

425
                # the stack is last-in-first-out, so we add children
426
                # in reverse order so they're processed left-to-right
427
                if node.children:
6✔
428
                    stack.extend((child, False) for child in node.children[::-1])
6✔
429

430
    def pre_and_postorder(self, include_self: bool = True) -> Generator[Self]:
6✔
431
        """Performs iteration over tree, visiting node before and after."""
432
        yield from self.preorder(include_self=include_self)
×
433
        yield from self.postorder(include_self=include_self)
×
434

435
    def ancestors(self) -> list[Self]:
6✔
436
        """Returns all ancestors back to the root."""
437
        result: list[Self] = []
6✔
438
        curr = self._parent
6✔
439
        while curr is not None:
6✔
440
            result.append(curr)
6✔
441
            curr = curr._parent
6✔
442
        return result
6✔
443

444
    def get_root(self) -> Self:
6✔
445
        """Returns root of the tree self is in."""
446
        curr = self
6✔
447
        while curr._parent is not None:
6✔
448
            curr = curr._parent
6✔
449
        return curr
6✔
450

451
    def rooted(self, edge_name: str) -> Self:
6✔
452
        """Returns a new tree with split at edge_name
453

454
        Parameters
455
        ----------
456
        edge_name
457
            name of the edge to split at. The length of edge_name will be
458
            halved. The new tree will have two children.
459
        """
460
        tree = self.deepcopy()
6✔
461
        if not self.is_root():
6✔
462
            msg = (
6✔
463
                f"cannot apply from non-root node {self.name!r}, "
464
                "use self.get_root() first"
465
            )
466
            raise TreeError(msg)
6✔
467

468
        if edge_name == "root":
6✔
469
            return tree
6✔
470

471
        tree.source = None
6✔
472
        node = tree.get_node_matching_name(edge_name)
6✔
473
        is_tip = node.is_tip()
6✔
474
        has_length = any(node.length is not None for node in self.preorder())
6✔
475
        # we put tips on the right
476
        right_name = edge_name if is_tip else f"{edge_name}-R"
6✔
477
        left_name = f"{edge_name}-root" if is_tip else f"{edge_name}-L"
6✔
478
        length = (node.length or 0) / 2
6✔
479
        parent = cast("Self", node.parent)
6✔
480
        parent.children.remove(node)
6✔
481
        node.parent = None
6✔
482
        left = node.unrooted_deepcopy()
6✔
483
        right = parent.unrooted_deepcopy()
6✔
484
        if is_tip and left.is_tip():
6✔
485
            left.name = right_name
6✔
486
            right.name = left_name
6✔
487
        else:
488
            left.name = left_name
6✔
489
            right.name = right_name
6✔
490

491
        if has_length:
6✔
492
            left.length = length
6✔
493
            right.length = length
6✔
494

495
        result = self.__class__(name="root", children=[left, right])
6✔
496
        result.source = self.source
6✔
497
        result.prune()
6✔
498
        return result
6✔
499

500
    def isroot(self) -> bool:
6✔
501
        """Returns True if root of a tree, i.e. no parent."""
502
        return self.is_root()
6✔
503

504
    def siblings(self) -> list[Self]:
6✔
505
        """Returns all nodes that are children of the same parent as self.
506

507
        Note: excludes self from the list. Dynamically calculated.
508
        """
509
        if self._parent is None:
6✔
510
            return []
6✔
511
        result = self._parent.children[:]
6✔
512
        result.remove(self)
6✔
513
        return result
6✔
514

515
    def iter_tips(self, include_self: bool = False) -> Generator[Self]:
6✔
516
        """Iterates over tips descended from self, [] if self is a tip."""
517
        # bail out in easy case
518
        if not self.children:
6✔
519
            if include_self:
6✔
520
                yield self
6✔
521
            return None
6✔
522
        # use stack-based method: robust to large trees
523
        stack = [self]
6✔
524
        while stack:
6✔
525
            curr = stack.pop()
6✔
526
            if curr.children:
6✔
527
                stack.extend(curr.children[::-1])  # 20% faster than reversed
6✔
528
            else:
529
                yield curr
6✔
530

531
    def tips(self, include_self: bool = False) -> list[Self]:
6✔
532
        """Returns tips descended from self, [] if self is a tip."""
533
        return list(self.iter_tips(include_self=include_self))
6✔
534

535
    def iter_nontips(self, include_self: bool = False) -> Generator[Self]:
6✔
536
        """Iterates over nontips descended from self
537

538
        Parameters
539
        ----------
540
        include_self
541
            if True (default is False), will return the current
542
            node as part of the list of nontips if it is a nontip.
543
        """
544
        for n in self.preorder(include_self=include_self):
6✔
545
            if n.children:
6✔
546
                yield n
6✔
547

548
    def nontips(self, include_self: bool = False) -> list[Self]:
6✔
549
        """Returns nontips descended from self."""
550
        return list(self.iter_nontips(include_self=include_self))
6✔
551

552
    def istip(self) -> bool:
6✔
553
        """Returns True if is tip, i.e. no children."""
554
        return not self.children
6✔
555

556
    def tip_children(self) -> list[Self]:
6✔
557
        """Returns direct children of self that are tips."""
558
        return [i for i in self.children if not i.children]
6✔
559

560
    def non_tip_children(self) -> list[Self]:
6✔
561
        """Returns direct children in self that have descendants."""
562
        return [i for i in self.children if i.children]
6✔
563

564
    def last_common_ancestor(self, other: Self) -> Self:
6✔
565
        """Finds last common ancestor of self and other, or None.
566

567
        Always tests by identity.
568
        """
569
        my_lineage = {id(node) for node in [self, *self.ancestors()]}
6✔
570
        curr: Self | None = other
6✔
571
        while curr is not None:
6✔
572
            if id(curr) in my_lineage:
6✔
573
                return curr
6✔
574
            curr = curr._parent
6✔
575
        msg = "No common ancestor found."
6✔
576
        raise TreeError(msg)
6✔
577

578
    def lowest_common_ancestor(self, tip_names: list[str]) -> Self:
6✔
579
        """Lowest common ancestor for a list of tipnames
580

581
        This should be around O(H sqrt(n)), where H is height and n is the
582
        number of tips passed in.
583
        """
584
        if len(tip_names) == 1:
6✔
585
            return self.get_node_matching_name(tip_names[0])
6✔
586

587
        tip_names_set: set[str] = set(tip_names)
6✔
588
        tips = [tip for tip in self.tips() if tip.name in tip_names_set]
6✔
589

590
        if len(tips) != len(tip_names_set):
6✔
591
            missing = tip_names_set - set(self.get_tip_names())
6✔
592
            msg = f"tipnames {missing} not present in self"
6✔
593
            raise ValueError(msg)
6✔
594

595
        # scrub tree
596
        for n in self.preorder(include_self=True):
6✔
597
            n.params.pop("black", None)
6✔
598

599
        for t in tips:
6✔
600
            prev = t
6✔
601
            curr = t.parent
6✔
602

603
            while curr and "black" not in curr.params:
6✔
604
                curr.params["black"] = [prev]
6✔
605
                prev = curr
6✔
606
                curr = curr.parent
6✔
607

608
            # increase black count, multiple children lead to here
609
            if curr:
6✔
610
                curr.params["black"].append(prev)
6✔
611

612
        curr = self
6✔
613
        while len(curr.params.get("black", [])) == 1:
6✔
614
            curr = curr.params.pop("black")[0]
6✔
615

616
        return curr
6✔
617

618
    # alias of last_common_ancestor (two nodes), not lowest_common_ancestor (tip names)
    lca = last_common_ancestor  # for convenience
6✔
619

620
    # support for more advanced tree operations
621

622
    def separation(self, other: Self | None) -> int:
6✔
623
        """Returns number of edges separating self and other."""
624
        # detect trivial case
625
        if self is other:
6✔
626
            return 0
6✔
627
        # otherwise, check the list of ancestors
628
        my_ancestors = dict.fromkeys(list(map(id, [self, *self.ancestors()])))
6✔
629
        count = 0
6✔
630
        while other is not None:
6✔
631
            if id(other) in my_ancestors:
6✔
632
                # need to figure out how many steps there were back from self
633
                curr: Self | None = self
6✔
634
                while curr is not None and curr is not other:
6✔
635
                    count += 1
6✔
636
                    curr = curr.parent
6✔
637
                return count
6✔
638
            count += 1
6✔
639
            other = other.parent
6✔
640
        msg = "Nodes do not belong to the same tree."
×
641
        raise TreeError(msg)
×
642

643
    def descendant_array(
6✔
644
        self, tip_list: list[str] | None = None
645
    ) -> tuple[npt.NDArray[numpy.bool], list[Self]]:
646
        """Returns numpy array with nodes in rows and descendants in columns.
647

648
        True indicates that the decendant is a descendant of that node
649
        False indicates that it is not
650

651
        Also returns a list of nodes in the same order as they are listed
652
        in the array.
653

654
        tip_list is a list of the names of the tips that will be considered,
655
        in the order they will appear as columns in the final array. Internal
656
        nodes will appear as rows in preorder traversal order.
657
        """
658

659
        # get a list of internal nodes
660
        node_list = [node for node in self.preorder() if node.children]
×
661
        node_list.sort()
×
662

663
        # get a list of tip names if one is not supplied
664
        if not tip_list:
×
665
            tip_list = self.get_tip_names()
×
666
            tip_list.sort()
×
667
        # make a blank array of the right dimensions to alter
668
        result = numpy.zeros([len(node_list), len(tip_list)], dtype=numpy.bool)
×
669
        # put 1 in the column for each child of each node
670
        for i, node in enumerate(node_list):
×
671
            children = [n.name for n in node.tips()]
×
672
            for j, dec in enumerate(tip_list):
×
673
                if dec in children:
×
674
                    result[i, j] = 1
×
675
        return result, node_list
×
676

677
    def _default_tree_constructor(
        self,
    ) -> Callable[[Self | None, PySeq[Self], dict[str, Any] | None], Self]:
        """``edge_from_edge`` of a ``TreeBuilder`` that constructs nodes of this class."""
        builder = TreeBuilder(constructor=self.__class__)
        return cast(
            "Callable[[Self | None, PySeq[Self], dict[str, Any] | None], Self]",
            builder.edge_from_edge,
        )
684

685
    def name_unnamed_nodes(self) -> None:
6✔
686
        """sets the Data property of unnamed nodes to an arbitrary value
687

688
        Internal nodes are often unnamed and so this function assigns a
689
        Internal nodes are often unnamed and so this function assigns a
690
        value for referencing."""
691
        # make a list of the names that are already in the tree
692
        names_in_use = [node.name for node in self.preorder() if node.name]
6✔
693
        # assign unique names to the Data property of nodes where Data = None
694
        name_index = 1
6✔
695
        for node in self.preorder():
6✔
696
            if not node.name:
6✔
697
                new_name = f"node{name_index!s}"
6✔
698
                # choose a new name if name is already in tree
699
                while new_name in names_in_use:
6✔
700
                    name_index += 1
6✔
701
                    new_name = f"node{name_index}"
6✔
702
                node.name = new_name
6✔
703
                names_in_use.append(new_name)
6✔
704
                name_index += 1
6✔
705

706
    def make_tree_array(
6✔
707
        self, dec_list: list[str] | None = None
708
    ) -> tuple[npt.NDArray[numpy.uint8], list[Self]]:
709
        """Makes an array with nodes in rows and descendants in columns.
710

711
        A value of 1 indicates that the decendant is a descendant of that node/
712
        A value of 0 indicates that it is not
713

714
        also returns a list of nodes in the same order as they are listed
715
        in the array"""
716
        # get a list of internal nodes
717
        node_list = [node for node in self.preorder() if node.children]
6✔
718
        node_list.sort()
6✔
719

720
        # get a list of tips() name if one is not supplied
721
        if not dec_list:
6✔
722
            dec_list = self.get_tip_names()
6✔
723
            dec_list.sort()
6✔
724
        # make a blank array of the right dimensions to alter
725
        result = numpy.zeros((len(node_list), len(dec_list)), dtype=numpy.uint8)
6✔
726
        # put 1 in the column for each child of each node
727
        for i, node in enumerate(node_list):
6✔
728
            children = [dec.name for dec in node.tips()]
6✔
729
            for j, dec in enumerate(dec_list):
6✔
730
                if dec in children:
6✔
731
                    result[i, j] = 1
6✔
732
        return result, node_list
6✔
733

734
    def remove_deleted(self, should_delete: Callable[[Self], bool]) -> None:
6✔
735
        """Removes all nodes where should_delete tests true.
736

737
        Internal nodes that have no children as a result of removing deleted
738
        are also removed.
739
        """
740
        # Traverse tree
741
        for node in self.postorder():
6✔
742
            # if node is to be deleted
743
            if should_delete(node):
6✔
744
                # Store current parent
745
                curr_parent = node.parent
6✔
746
                # Set current node's parent to None (this deletes node)
747
                node.parent = None
6✔
748
                # While there are no chilren at node and not at root
749
                while (curr_parent is not None) and (not curr_parent.children):
6✔
750
                    # Save old parent
751
                    old_parent = curr_parent
6✔
752
                    # Get new parent
753
                    curr_parent = curr_parent.parent
6✔
754
                    # remove old node from tree
755
                    old_parent.parent = None
6✔
756

757
    def prune(
        self: Self,
        keep_root: bool = False,
        params_merge_callback: Callable[
            [dict[str, Any], dict[str, Any]], dict[str, Any]
        ]
        | None = None,
    ) -> None:
        """removes nodes with one child

        Parameters
        ----------
        keep_root
            If True, a root with a single child is retained.
        params_merge_callback
            How to merge two params dicts when pruning.
            The first argument is the parent node's params,
            the second argument is the child node's params.
            It should return the new params dictionary. For internal
            collapses the result is stored on the surviving child. When the
            root's single child is collapsed, the result is stored on the
            root.

        Notes
        -----
        Mutates the tree in-place. Internal nodes with only one child will be
        merged (except as specified by keep_root). When a node is collapsed,
        its length (if set) is added to its child's length. The child keeps
        its own name and support.

        NOTE(review): the child is re-attached through the ``parent`` setter,
        which appends to the end of the new parent's ``children``. A
        collapsed child therefore moves to the last position among its new
        siblings. Also, if the root's single child is a tip, it has no
        grand-children to re-attach, so the root is presumably left with no
        children (assuming ``remove_node`` detaches it). Confirm both are
        intended.
        """
        while True:
            # every non-root node that has exactly one child
            nodes_to_remove = [
                n
                for n in self.iter_nontips()
                if n.parent is not None and len(n.children) == 1
            ]
            if not nodes_to_remove:
                break

            for node in nodes_to_remove:
                curr_parent = node.parent
                child = node.children[0]
                # detach node, then hang its only child directly off node's parent
                node.parent = None
                child.parent = curr_parent

                # the collapsed edge's length is carried into the child's edge
                if node.length is not None:
                    child.length = (child.length or 0) + node.length

                # For support, we keep the child's support. Do nothing.

                if params_merge_callback is not None:
                    child.params = params_merge_callback(node.params, child.params)

        # root having one child is edge case
        if not keep_root and len(self.children) == 1:
            child = self.children[0]

            grand_children = list(child.children)

            # Ignore the child's length as a root can't have a length
            # Unclear what should be done about support as well, keep current support.

            if params_merge_callback is not None:
                self.params = params_merge_callback(self.params, child.params)

            # replace the single child by its own children, attached directly to self
            self.remove_node(child)
            for grand_child in grand_children:
                grand_child.parent = self
6✔
820

821
    def same_shape(self, other: Self) -> bool:
6✔
822
        """Ignores lengths and order, so trees should be sorted first"""
823
        if len(self.children) != len(other.children):
6✔
824
            return False
×
825
        if self.children:
6✔
826
            for self_child, other_child in zip(
6✔
827
                self.children,
828
                other.children,
829
                strict=False,
830
            ):
831
                if not self_child.same_shape(other_child):
6✔
832
                    return False
6✔
833
            return True
6✔
834
        return self.name == other.name
6✔
835

836
    def to_rich_dict(self) -> dict[str, Any]:
        """Return a dict describing this tree.

        Keys are ``newick`` (internal and root names included, no trailing
        semicolon, names not escaped), ``edge_attributes`` (a copy of each
        edge's params keyed by edge name), ``length_and_support`` (each edge's
        length and support keyed by edge name), ``type`` and ``version``.
        Edges come from ``get_edge_vector(include_root=True)``.
        """
        newick = self.get_newick(
            with_node_names=True,
            semicolon=False,
            escape_name=False,
            with_root_name=True,
        )
        edges = self.get_edge_vector(include_root=True)
        return {
            "newick": newick,
            "edge_attributes": {edge.name: edge.params.copy() for edge in edges},
            "length_and_support": {
                edge.name: {"length": edge.length, "support": edge.support}
                for edge in edges
            },
            "type": get_object_provenance(self),
            "version": __version__,
        }
    def to_json(self) -> str:
        """Return the tree serialised as a JSON string.

        The JSON object is the dict produced by ``to_rich_dict``.
        """
        rich = self.to_rich_dict()
        return json.dumps(rich)
    def get_newick(
        self,
        with_distances: bool = False,
        semicolon: bool = True,
        escape_name: bool = True,
        with_node_names: bool = False,
        with_root_name: bool = False,
    ) -> str:
        """Return the newick string of this node and its descendants

        Parameters
        ----------
        with_distances
            include value of node length attribute if present.
        semicolon
            end tree string with a semicolon (only applied when this node is
            the root)
        escape_name
            if any of these characters []'"() are within the
            nodes name, wrap the name in single quotes
        with_node_names
            includes internal node names
        with_root_name
            if True and with_node_names, the root node will have
            its name included

        Notes
        -----
        Uses an explicit stack instead of recursion: each node is visited once
        to schedule its children and again to assemble its string once all of
        its children have been formatted.
        """
        formatted: dict[int, str] = {}  # id(node) -> newick for that subtree
        todo: list[tuple[Self, bool]] = [(self, False)]

        while todo:
            current, children_done = todo.pop()
            if not children_done:
                # revisit this node after all of its children are formatted
                todo.append((current, True))
                todo.extend((kid, False) for kid in current.children)
                continue

            label = _format_node_name(
                current,
                with_node_names=with_node_names,
                escape_name=escape_name,
                with_distances=with_distances,
                with_root_name=with_root_name,
            )
            kid_strings = [formatted[id(kid)] for kid in current.children]
            formatted[id(current)] = (
                f"({','.join(kid_strings)}){label}" if kid_strings else label
            )

        newick = formatted[id(self)]
        if self.is_root() and semicolon:
            newick = f"{newick};"
        return newick
    def remove_node(self, target: Self) -> bool:
        """Remove ``target`` from this node's children, matching by identity.

        Only direct children are searched; the first identical child is
        deleted via ``del self[index]``. Returns True if ``target`` was found,
        False otherwise.
        """
        position = next(
            (idx for idx, kid in enumerate(self.children) if kid is target),
            None,
        )
        if position is None:
            return False
        del self[position]
        return True
    def get_edge_names(
        self,
        tip_name_1: str,
        tip_name_2: str,
        clade: bool = True,
        stem: bool = False,
        outgroup_name: str | None = None,
    ) -> list[str]:
        """Return the list of stem and/or sub tree (clade) edge name(s).

        This is done by finding the last common ancestor of the two tips, and
        then getting the list of names. If the clade traverses the root, then
        use the outgroup_name argument to ensure valid specification.

        Parameters
        ----------
        tip_name_1/2
            edge 1/2 names
        stem
            whether the name of the clade stem edge is returned (first in
            the result).
        clade
            whether the names of the edges within the clade are
            returned
        outgroup_name
            if provided the calculation is done on a version of
            the tree re-rooted relative to the provided tip.

        Raises
        ------
        TreeError
            if ``outgroup_name`` is not a tip, or if ``stem`` is True and the
            last common ancestor of the two tips is the root.

        Usage:
            The returned list can be used to specify subtrees for special
            parameterisation. For instance, say you want to allow the primates
            to have a different value of a particular parameter. In this case,
            provide the results of this method to the parameter controller
            method `set_param_rule()` along with the parameter name etc..
        """
        # If an outgroup is specified, work on a temporary copy produced by
        # outgroup.unrooted_deepcopy() so clades are defined relative to the
        # outgroup. The copy is only used here and then discarded.
        root = self
        if outgroup_name is not None:
            outgroup = self.get_node_matching_name(outgroup_name)
            if not outgroup.is_tip():
                msg = f"Outgroup ({outgroup_name!r}) is not a tip"
                raise TreeError(msg)
            root = outgroup.unrooted_deepcopy()

        join_edge = root.get_connecting_node(tip_name_1, tip_name_2)

        edge_names: list[str] = []

        if stem:
            # is_root() matches the predicate naming used across this class
            if join_edge.is_root():
                msg = f"LCA({tip_name_1},{tip_name_2}) is the root and so has no stem"
                raise TreeError(msg)
            edge_names.append(join_edge.name)

        if clade:
            # every node below the join edge, one child subtree at a time
            for child in join_edge.children:
                edge_names.extend(child.get_node_names(include_self=True))

        return edge_names
    def get_neighbours_except(self, parent: Self | None = None) -> list[Self]:
        """Return adjacent nodes other than ``parent``.

        Adjacent nodes are this node's children followed by its parent (when
        there is one). ``parent`` is excluded by identity. This allows the
        tree to be walked as if it were unrooted.
        """
        neighbours: list[Self] = []
        for adjacent in (*self.children, self.parent):
            if adjacent is None or adjacent is parent:
                continue
            neighbours.append(adjacent)
        return neighbours
    def get_sub_tree(
        self,
        names: Iterable[str],
        ignore_missing: bool = False,
        tips_only: bool = False,
        as_rooted: bool = False,
    ) -> Self:
        """Return a new tree containing the nodes named in ``names``.

        Parameters
        ----------
        names
            node names to retain. All ancestors of each matching node are kept.
        ignore_missing
            if False, raise a ValueError when ``names`` contains names
            that aren't nodes in the tree
        tips_only
            if False, all descendants of a matching internal node are kept as
            well; if True, a matching internal node is kept without them
        as_rooted
            if True, the resulting subtree root will be as resolved. Otherwise,
            the subtree is coerced to have the same number of children as self.

        Notes
        -----
        The result is built from new instances of ``self.__class__`` (params
        are copied with ``.copy()``), has single-child internal nodes removed
        via ``prune()``, and its root is named "root".
        """
        wanted = set(names)
        keep: dict[int, Self] = {}  # id(original node) -> original node
        seen: set[str] = set()
        for candidate in self.preorder(include_self=True):
            if candidate.name not in wanted:
                continue

            seen.add(candidate.name)
            keep[id(candidate)] = candidate
            # walk towards the root, stopping at the first ancestor already kept
            ancestor = candidate.parent
            while ancestor is not None and id(ancestor) not in keep:
                keep[id(ancestor)] = ancestor
                ancestor = ancestor.parent

            if not tips_only and not candidate.is_tip():
                keep.update((id(desc), desc) for desc in candidate.preorder())

        if seen != wanted and not ignore_missing:
            msg = f"edges {wanted - seen} not found in tree"
            raise ValueError(msg)

        node_type = self.__class__
        # one clone per kept node, created in the same order as `keep`
        clones: dict[int, Self] = {
            old_id: node_type(
                original.name,
                params=original.params.copy(),
                length=original.length,
                support=original.support,
            )
            for old_id, original in keep.items()
        }

        # wire clones together, visiting nodes in the original insertion order
        for old_id, original in keep.items():
            if original.parent is None:
                continue
            # assigning parent also adds the clone as a child of the new parent
            clones[old_id].parent = clones[id(original.parent)]

        pruned = clones[id(self)]
        pruned.prune()
        if not (as_rooted or len(self.children) == len(pruned.children)):
            if len(self.children) > 2:
                pruned = pruned.unrooted()
            else:
                # root at an arbitrary child; "new-root" is a magic placeholder
                anchor = pruned.children[0]
                anchor.name = anchor.name if anchor.name else "new-root"
                pruned = pruned.rooted(anchor.name)
        pruned.name = "root"
        return pruned
    def _edgecount(self, parent: Self, cache: dict[tuple[int, int], int]) -> int:
        """Count the edges beyond ``parent`` in the direction of ``self``.

        The tree is treated as unrooted: the count is 1 (for self) plus the
        counts of every neighbour of self other than ``parent``. Results are
        memoised in ``cache`` under the key ``(id(parent), id(self))``.
        """
        key = (id(parent), id(self))
        if key in cache:
            return cache[key]
        total = 1
        for neighbour in self.get_neighbours_except(parent):
            total += neighbour._edgecount(self, cache)
        cache[key] = total
        return total
    def _imbalance(
        self, parent: Self | None, cache: dict[tuple[int, int], int]
    ) -> tuple[int, int, Self]:
        """Split the edge count from here (excluding via ``parent``).

        Returns ``(heaviest, rest, heaviest_neighbour)`` where ``heaviest`` is
        the largest ``_edgecount`` among neighbours other than ``parent``,
        ``rest`` is the sum over the remaining neighbours, and
        ``heaviest_neighbour`` is the first neighbour achieving the maximum
        (self when there are no neighbours). ``cache`` is shared with the
        ``_edgecount`` calls so counts are not stored on the tree itself.
        """
        best_weight = 0
        running_total = 0
        best_node = self
        for neighbour in self.get_neighbours_except(parent):
            weight = neighbour._edgecount(self, cache)
            running_total += weight
            # strict comparison: on ties the earliest neighbour is retained
            if weight > best_weight:
                best_weight, best_node = weight, neighbour
        return best_weight, running_total - best_weight, best_node
    def sorted(self, sort_order: list[str] | None = None) -> Self:
        """An equivalent tree with tips in sort_order.

        Parameters
        ----------
        sort_order
            names in priority order (any sequence of str). Tips not listed
            are ranked after all listed names, in alphabetical order.

        Notes
        -----
        If sort_order is not specified then alphabetical order is used.
        At each node starting from root, the algorithm will try to put
        the descendant which contains the smallest index tip on the left.
        When a name occurs more than once in the combined ordering, its
        first (highest priority) position is used.
        """
        tip_names = self.get_tip_names()
        tip_names.sort()
        # caller's order first, then alphabetical tips as the fall-back
        full_sort_order = [*(sort_order or []), *tip_names]
        # keep the FIRST occurrence of each name: a dict comprehension would
        # let the later alphabetical entry overwrite the caller's priority
        score_map: dict[str, int] = {}
        for rank, name in enumerate(full_sort_order):
            score_map.setdefault(name, rank)

        constructor = self._default_tree_constructor()

        scores: dict[Self, int | None] = {}
        rebuilt: dict[Self, Self] = {}

        infinity = float("inf")
        for node in self.postorder():
            if node.is_tip():
                score: int | None = score_map[node.name]
                tree = node.deepcopy()
            else:
                child_info = [(scores[ch], rebuilt[ch]) for ch in node.children]
                # Sort children by score, None is treated as +infinity
                child_info.sort(key=lambda x: (infinity if x[0] is None else x[0]))
                children = tuple(child for _, child in child_info)
                tree = constructor(node, children, None)
                # after sorting, the first non-None score is the smallest
                non_null = [s for s, _ in child_info if s is not None]
                score = non_null[0] if non_null else None
            scores[node] = score
            rebuilt[node] = tree

        return rebuilt[self]
    def ladderise(self) -> Self:
        """Return an equivalent tree ordered by a ladderise sort.

        Notes
        -----
        At each internal node, children are ranked by their number of
        descendant tips (fewest first), with ties broken by the first tip name
        in each child's already-ordered tip list. The resulting overall tip
        order is passed to ``sorted``.
        """
        # node -> (number of tips below, tip names in ladderised order)
        summary: dict[Self, tuple[int, list[str]]] = {}
        for node in self.postorder():
            if node.is_tip():
                summary[node] = (1, [node.name])
                continue
            ranked = sorted(
                node.children,
                key=lambda kid: (summary[kid][0], summary[kid][1][0]),
            )
            summary[node] = (
                sum(summary[kid][0] for kid in ranked),
                [name for kid in ranked for name in summary[kid][1]],
            )
        return self.sorted(sort_order=summary[self][1])
    # US-spelling alias: the same function object as ladderise
    ladderize = ladderise
    def _ascii_art(
6✔
1210
        self,
1211
        char1: str = "-",
1212
        show_internal: bool = True,
1213
        compact: bool = False,
1214
    ) -> tuple[list[str], int]:
1215
        length = 10
6✔
1216
        pad = " " * length
6✔
1217
        pa = " " * (length - 1)
6✔
1218
        namestr = self.name
6✔
1219
        if self.children:
6✔
1220
            mids: list[int] = []
6✔
1221
            result: list[str] = []
6✔
1222
            for c in self.children:
6✔
1223
                if c is self.children[0]:
6✔
1224
                    char2 = "/"
6✔
1225
                elif c is self.children[-1]:
6✔
1226
                    char2 = "\\"
6✔
1227
                else:
1228
                    char2 = "-"
×
1229
                clines, mid = c._ascii_art(char2, show_internal, compact)
6✔
1230
                mids.append(mid + len(result))
6✔
1231
                result.extend(clines)
6✔
1232
                if not compact:
6✔
1233
                    result.append("")
6✔
1234
            if not compact:
6✔
1235
                result.pop()
6✔
1236
            lo, hi, end = mids[0], mids[-1], len(result)
6✔
1237
            prefixes = (
6✔
1238
                [pad] * (lo + 1) + [pa + "|"] * (hi - lo - 1) + [pad] * (end - hi)
1239
            )
1240
            mid = (lo + hi) // 2
6✔
1241
            prefixes[mid] = char1 + "-" * (length - 2) + prefixes[mid][-1]
6✔
1242
            result = [pre + res for pre, res in zip(prefixes, result, strict=False)]
6✔
1243
            if show_internal:
6✔
1244
                stem = result[mid]
6✔
1245
                result[mid] = stem[0] + namestr + stem[len(namestr) + 1 :]
6✔
1246
            return result, mid
6✔
1247
        return [char1 + "-" + namestr], 0
6✔
1248

1249
    def ascii_art(self, show_internal: bool = True, compact: bool = False) -> str:
        """Return an ASCII drawing of the tree as a newline-joined string.

        Parameters
        ----------
        show_internal
            if True, internal edge names are drawn.
        compact
            if True, no blank spacer lines are added, giving one line per tip.
        """
        drawing = self._ascii_art(show_internal=show_internal, compact=compact)[0]
        return "\n".join(drawing)
    def write(
        self,
        filename: str | os.PathLike[str],
        with_distances: bool = True,
        format_name: str | None = None,
    ) -> None:
        """Save the tree to filename

        Parameters
        ----------
        filename
            path to write the tree to.
        with_distances
            whether branch lengths are included in the newick string. Not used
            for json output.
        format_name
            "json" writes the cogent3 json format (via ``to_json``); any other
            value writes newick. If not provided, the format is taken from the
            filename suffix; an explicit value overrides the suffix.

        Notes
        -----
        Only the cogent3 json and newick tree formats are supported. The file
        is written with ``atomic_write``.
        """
        file_format, _ = get_format_suffixes(filename)
        format_name = format_name or file_format
        if format_name == "json":
            data = self.to_json()
        else:
            data = self.get_newick(with_distances=with_distances)

        with atomic_write(filename, mode="wt") as outf:
            # write the string in one call; writelines() on a str would
            # iterate and write it one character at a time
            outf.write(data)
    def get_node_names(
        self, include_self: bool = True, tips_only: bool = False
    ) -> list[str]:
        """Return the names of nodes from this node downwards.

        Parameters
        ----------
        include_self : bool
            if True, this node is included in the result (listed first when
            tips_only is False); if False, it is excluded.
        tips_only : bool
            if True, only tip names are returned (from ``self.tips()``);
            otherwise names are listed in preorder.
        """
        walker = self.tips if tips_only else self.preorder
        return [node.name for node in walker(include_self=include_self)]
    def get_tip_names(self, include_self: bool = True) -> list[str]:
        """Return the names of all tips contained by this edge.

        Raises
        ------
        TreeError
            if any tip name is the empty string.
        """
        tip_names = self.get_node_names(include_self=include_self, tips_only=True)
        if "" not in tip_names:
            return tip_names
        msg = "Tree contains unnamed nodes."
        raise TreeError(msg)
    def get_edge_vector(self, include_root: bool = True) -> list[Self]:
        """Return every edge (node) of the tree as a list in postorder.

        Parameters
        ----------
        include_root
            whether this (root) node is included in the list
        """
        return [*self.postorder(include_self=include_root)]
    def get_node_matching_name(self, name: str) -> Self:
        """Return the first node, in preorder from self, whose name is ``name``.

        Raises
        ------
        TreeError if no edge with the name is found
        """
        found = next(
            (node for node in self.preorder(include_self=True) if node.name == name),
            None,
        )
        if found is None:
            msg = f"No node named '{name}' in {self.get_tip_names()}"
            raise TreeError(msg)
        return found
    def get_connecting_node(self, name1: str, name2: str) -> Self:
        """Return the last common ancestor of the two named nodes.

        Both names are resolved with ``get_node_matching_name``, which raises
        TreeError for an unknown name.
        """
        first, second = (self.get_node_matching_name(n) for n in (name1, name2))
        return first.last_common_ancestor(second)
    def get_connecting_edges(self, name1: str, name2: str) -> list[Self]:
        """Return the edges (nodes) on the path between two named nodes.

        The path runs from the node named ``name1`` up to their last common
        ancestor (LCA) and down to the node named ``name2``. Either node may
        be an ancestor of the other. If both are tips, the LCA is excluded
        from the result.
        """
        edge1 = self.get_node_matching_name(name1)
        edge2 = self.get_node_matching_name(name2)
        include_parent = not (edge1.is_tip() and edge2.is_tip())

        # same as get_connecting_node(name1, name2) without repeating lookups
        lca = edge1.last_common_ancestor(edge2)

        # edge1 upwards, up to and including the LCA
        up_path = [edge1, *edge1.ancestors()]
        up_path = up_path[: up_path.index(lca) + 1]
        # edge2 upwards, stopping just below the LCA (empty if edge2 IS the LCA,
        # which previously raised ValueError because a node is not its own
        # ancestor)
        down_path = [edge2, *edge2.ancestors()]
        down_path = down_path[: down_path.index(lca)]

        node_path = up_path + down_path[::-1]
        if not include_parent:
            node_path.remove(lca)
        return node_path
    def get_param_value(self, param: str, edge: str) -> Any:  # noqa: ANN401
        """Return ``params[param]`` of the node named ``edge``.

        The node is located with ``get_node_matching_name``.
        """
        node_params = self.get_node_matching_name(edge).params
        return node_params[param]
    def set_param_value(self, param: str, edge: str, value: Any) -> None:  # noqa: ANN401
        """Store ``value`` as ``params[param]`` on the node named ``edge``.

        The node is located with ``get_node_matching_name``.
        """
        target = self.get_node_matching_name(edge)
        target.params[param] = value
    def reassign_names(
        self, mapping: dict[str, str], nodes: list[Self] | None = None
    ) -> None:
        """Rename nodes in place according to ``mapping``.

        mapping : dict, old_name -> new_name
        nodes : specific nodes for renaming (such as just tips, etc...);
            defaults to every node yielded by ``self.preorder()``.
        Nodes whose name is not a key of ``mapping`` are left unchanged.
        """
        targets = list(self.preorder()) if nodes is None else nodes
        for node in targets:
            if node.name not in mapping:
                continue
            node.name = mapping[node.name]
    def multifurcating(
        self,
        num: int,
        eps: float | None = None,
        name_unnamed: bool = False,
    ) -> Self:
        """return a new tree with every node having num or fewer children

        Parameters
        ----------
        num : int
            the maximum number of children a node can have, must be >= 2
        eps : float
            branch length given to each newly created internal node when its
            first child has a length. Defaults to 0.0.
        name_unnamed : bool
            if True, unnamed internal nodes are given unique random names of
            the form "AUTOGENERATED_NAME_<letters>"

        Raises
        ------
        TreeError
            if num < 2

        Notes
        -----
        Works on ``self.copy()`` and returns that modified copy. While a node
        has more than ``num`` children, its last ``num`` children are grouped
        under a new unnamed internal node.
        """
        if num < 2:
            msg = "Minimum number of children must be >= 2"
            raise TreeError(msg)

        if eps is None:
            eps = 0.0

        constructor = self.__class__

        new_tree = self.copy()

        for n in new_tree.preorder(include_self=True):
            while len(n.children) > num:
                # the constructor is expected to re-parent these children
                # (removing them from n); this loop relies on that to shrink n
                new_node = constructor("", children=n.children[-num:])

                if new_node[0].length is not None:
                    new_node.length = eps

                n.append(new_node)

        if name_unnamed:
            alpha = "abcdefghijklmnopqrstuvwxyz"
            alpha += alpha.upper()
            base = "AUTOGENERATED_NAME_%s"

            unnamed = [n for n in new_tree.nontips() if not n.name]
            # names already in use; generated names must not collide with them
            taken = {n.name for n in new_tree.preorder(include_self=True)}

            # scale the random names by tree size, with at least one letter,
            # and grow them until the name space is at least twice the number
            # of names needed so collision redraws terminate quickly
            size = max(1, int(numpy.ceil(numpy.log(len(new_tree.tips())))))
            while len(alpha) ** size <= 2 * (len(unnamed) + len(taken)):
                size += 1

            for n in unnamed:
                # redraw on collision so every generated name is unique
                while True:
                    name = base % "".join(random.choice(alpha) for _ in range(size))
                    if name not in taken:
                        break
                taken.add(name)
                n.name = name

        return new_tree
    def bifurcating(
        self,
        eps: float | None = None,
        name_unnamed: bool = False,
    ) -> Self:
        """Return ``multifurcating`` with at most two children per node.

        ``eps`` and ``name_unnamed`` are forwarded unchanged.
        """
        return self.multifurcating(num=2, eps=eps, name_unnamed=name_unnamed)
    def get_nodes_dict(self) -> dict[str, Self]:
        """Return a dict mapping each node name to its node.

        Nodes are taken from ``self.preorder()``.

        Raises
        ------
        TreeError
            if two nodes share a name
        """
        nodes_by_name: dict[str, Self] = {}
        for node in self.preorder():
            # setdefault returns the existing node when the name is a repeat
            if nodes_by_name.setdefault(node.name, node) is not node:
                msg = "get_nodes_dict requires unique node names"
                raise TreeError(msg)
        return nodes_by_name
    def subset(self) -> frozenset[str]:
        """Return the set of tip names descending from this node."""
        descendant_tips = self.get_tip_names(include_self=False)
        return frozenset(descendant_tips)
    def subsets(self) -> frozenset[frozenset[str]]:
        """Return the tip-name sets of the clades below this node.

        Every internal descendant of self contributes the frozenset of tip
        names beneath it, provided that set has more than one name. The set
        for self itself is not included.

        Notes
        -----
        Node params are left untouched: intermediate sets are held in a local
        mapping keyed by node identity instead of a temporary params entry,
        which could clobber a user "leaf_set" param or be left behind if an
        error occurred part-way through.
        """
        leaf_sets: dict[int, frozenset[str]] = {}
        sets: list[frozenset[str]] = []
        for node in self.postorder(include_self=False):
            if not node.children:
                leaf_sets[id(node)] = frozenset([node.name])
                continue
            # children are visited before their parent, so their sets exist;
            # each is needed only once, so pop to release memory early
            leaf_set: frozenset[str] = reduce(
                or_, [leaf_sets.pop(id(child)) for child in node.children]
            )
            if len(leaf_set) > 1:
                sets.append(leaf_set)
            leaf_sets[id(node)] = leaf_set
        return frozenset(sets)
    def compare_by_subsets(
        self, other: Self, exclude_absent_taxa: bool = False
    ) -> float:
        """Return the fraction of clade subsets that differ between two trees.

        Computed from ``subsets()`` of each tree as
        ``1 - 2 * n_shared / (n_self + n_other)``; returns 1 when there are no
        subsets to compare.

        Parameters
        ----------
        other
            a tree object compatible with PhyloNode (it must provide
            ``subsets`` and ``subset``).
        exclude_absent_taxa
            if True, each subset is first restricted to tip names present in
            both trees and subsets left with fewer than two names are dropped.
            Otherwise names present in only one of the two trees count as
            mismatches.
        """
        mine, theirs = self.subsets(), other.subsets()
        if exclude_absent_taxa:
            shared_tips = self.subset() & other.subset()

            def _restrict(
                groups: frozenset[frozenset[str]],
            ) -> frozenset[frozenset[str]]:
                restricted = (group & shared_tips for group in groups)
                return frozenset(group for group in restricted if len(group) > 1)

            mine, theirs = _restrict(mine), _restrict(theirs)

        denominator = len(mine) + len(theirs)
        if not denominator:  # no subsets remain, so maximal distance
            return 1
        overlap = len(mine & theirs)
        return 1 - 2 * overlap / float(denominator)
    def tip_to_tip_distances(
        self, names: PySeqStr | None = None, default_length: float | None = None
    ) -> DistanceMatrix:
        """Return a DistanceMatrix of path-length distances between all tips.

        Parameters
        ----------
        names
            if provided, distances are computed on ``self.get_sub_tree(names)``.
        default_length
            length used for any node whose length is None. If not provided,
            1 is used when no node in the tree has a length, otherwise 0.

        Notes
        -----
        The distance between two tips is the sum of node lengths from each
        tip up to (but excluding) their last common ancestor. Matrix rows and
        columns are labelled with ``self.get_tip_names()``.
        """
        from cogent3.evolve.fast_distance import DistanceMatrix

        if names is not None:
            subtree = self.get_sub_tree(names)
            return subtree.tip_to_tip_distances(
                default_length=default_length,
            )

        if default_length is None:
            # infer a default only when the caller did not supply one
            default_length = (
                1 if all(node.length is None for node in self.preorder()) else 0
            )
        tips = list(self.tips())

        # per tip (by position, so duplicate names cannot collide):
        # id(node) -> cumulative distance from the tip, for the tip and all of
        # its ancestors up to the root
        dist_from_tip: list[dict[int, float]] = []
        for tip in tips:
            depths: dict[int, float] = {}
            current: Self | None = tip  # type: ignore[assignment]
            dist = 0.0
            while current is not None:
                depths[id(current)] = dist
                dist += (
                    current.length if current.length is not None else default_length
                )
                current = current.parent
            dist_from_tip.append(depths)

        num_tips = len(tips)
        dists = numpy.zeros((num_tips, num_tips), float)
        for i, j in combinations(range(num_tips), 2):
            depths1 = dist_from_tip[i]
            depths2 = dist_from_tip[j]
            common = depths1.keys() & depths2.keys()

            if not common:
                msg = f"No common ancestor for {tips[i].name} and {tips[j].name}"
                raise ValueError(msg)

            # the last common ancestor is the shared node closest to tip i
            lca = min(common, key=depths1.__getitem__)
            dists[i, j] = dists[j, i] = depths1[lca] + depths2[lca]

        return DistanceMatrix.from_array_names(dists, self.get_tip_names())
    def get_figure(
        self,
        style: Literal["square", "circular", "angular", "radial"] = "square",
        **kwargs: Any,
    ) -> Dendrogram:
        """
        returns a Dendrogram for plotting this phylogeny

        Parameters
        ----------
        style : string
            'square', 'angular', 'radial' or 'circular' (case-insensitive)
        kwargs
            arguments passed to Dendrogram constructor

        Raises
        ------
        ValueError
            if style is not one of the supported names
        """
        from cogent3.draw.dendrogram import Dendrogram

        normalised = cast(
            "Literal['square', 'circular', 'angular', 'radial']", style.lower()
        )
        supported = ("square", "circular", "angular", "radial")
        if normalised in supported:
            return Dendrogram(self, style=normalised, **kwargs)

        msg = f"{normalised} not in supported types {supported}"
        raise ValueError(msg)
1615

1616
    def balanced(self) -> Self:
        """Tree 'rooted' here with no neighbour having > 50% of the edges.

        Usage:
            Using a balanced tree can substantially improve performance of
            the likelihood calculations. Note that the resulting tree has a
            different orientation with the effect that specifying clades or
            stems for model parameterisation should be done using the
            'outgroup_name' argument.
        """
        # this should work OK on ordinary 3-way trees, not so sure about
        # other cases.  Given 3 neighbours, if one has > 50% of edges it
        # can only improve things to divide it up, worst case:
        # (51),25,24 -> (50,1),49.
        # If no neighbour has >50% we can't improve on where we are, eg:
        # (49),25,26 -> (20,19),51
        previous = None
        current = self
        weight_behind = 0
        imbalance_cache: dict[tuple[int, int], int] = {}
        while True:
            heaviest, leftover, candidate = current._imbalance(
                previous,
                imbalance_cache,
            )
            weight_behind += leftover
            if heaviest <= weight_behind + 2:
                # no neighbour dominates any more: root here
                return current.unrooted_deepcopy()
            # step towards the heavy side; the edge just crossed counts too
            previous, current = current, candidate
            weight_behind += 1
1648

1649
    def same_topology(self, other: Self) -> bool:
        """Tests whether two trees have the same topology.

        Both trees are rooted at the same tip (the first of self's tip names)
        and sorted by self's tip names before their shapes are compared.
        Trees whose sets of tip names differ can never share a topology and
        return False without being re-rooted.
        """
        if self is other:
            return True
        tip_names = self.get_tip_names()
        if set(tip_names) != set(other.get_tip_names()):
            return False
        root_at = tip_names[0]
        me = self.rooted(root_at).sorted(tip_names)
        them = other.rooted(root_at).sorted(tip_names)
        return me.same_shape(them)
1656

1657
    def unrooted_deepcopy(
        self,
        parent: Self | None = None,
    ) -> Self:
        """
        Returns a copy of the tree, rooted at self, using unrooted traversal.

        Each node is treated as connected to its parent and children, so the
        copy is built outward from self regardless of where the original root
        is. The new root is named "root". ``prune(keep_root=True)`` is applied
        to the new tree before it is returned, so callers do not need to
        prune it themselves.

        Parameters
        ----------
        parent
            a neighbour of self to exclude from the traversal; None (default)
            means every neighbour of self is included
        """
        constructor = self._default_tree_constructor()

        # node_map maps id(original_node) -> new_node
        node_map: dict[int, Self] = {}
        # stack is last in first out
        # stack stores (original_node, parent_we_came_from, state)
        # False state is the first visit, discover neighbors
        # True state is the second visit, construct new node
        stack = [(self, parent, False)]
        while stack:
            node, parent_node, state = stack.pop()

            if not state:
                # put the node back (for its second visit), then its
                # neighbours on top so they are built first
                stack.append((node, parent_node, True))
                stack.extend(
                    (neigh, node, False)
                    for neigh in node.get_neighbours_except(parent_node)
                )
            else:
                # children are created and in node_map prior to their parents
                # being visited
                children = [
                    node_map[id(neigh)]
                    for neigh in node.get_neighbours_except(parent_node)
                ]

                # pick the original node whose edge attributes the new node
                # takes: when we arrived from one of node's children (moving
                # towards the old root) the connecting edge belongs to that
                # child, otherwise it belongs to node itself
                if parent_node is None:
                    edge = None
                elif parent_node.parent is node:
                    edge = parent_node
                else:
                    edge = node

                new_node = constructor(edge, tuple(children), None)
                node_map[id(node)] = new_node
                if parent_node is None:
                    new_node.name = "root"

        new_root = node_map[id(self)]
        new_root.prune(keep_root=True)
        return new_root
1710

1711
    def unrooted(self) -> Self:
        """A tree with at least 3 children at the root.

        When the root has fewer than 3 children, the first child that has
        children of its own is dissolved: its children become direct children
        of the new root, with the dissolved edge's length added to theirs when
        both lengths are known. Every other child is deep-copied unchanged.
        """
        build = self._default_tree_constructor()
        expand_pending = len(self.children) < 3
        root_children: list[Self] = []
        for child in self.children:
            if not (child.children and expand_pending):
                root_children.append(child.deepcopy())
                continue
            for grandchild in child.children:
                copied = grandchild.deepcopy()
                if copied.length is not None and child.length is not None:
                    copied.length += child.length
                root_children.append(copied)
            expand_pending = False
        return build(self, root_children, None)
1727

1728
    def rooted_at(self, edge_name: str) -> Self:
        """Return a new tree rooted at the named internal node.

        Usage:
            This can be useful for drawing unrooted trees with an orientation
            that reflects knowledge of the true root location.

        Raises
        ------
        TreeError
            if edge_name identifies a tip
        """
        new_root = self.get_node_matching_name(edge_name)
        if new_root.children:
            return new_root.unrooted_deepcopy()
        msg = f"Can't use a tip ({edge_name!r}) as the root"
        raise TreeError(msg)
1740

1741
    def rooted_with_tip(self, outgroup_name: str) -> Self:
        """A new tree with the named tip as one of the root's children.

        The copy is rooted at the parent of the named node via
        ``unrooted_deepcopy()``.
        """
        return cast(
            "Self", self.get_node_matching_name(outgroup_name).parent
        ).unrooted_deepcopy()
1746

1747
    def tree_distance(self, other: PhyloNode, method: str | None = None) -> int:
        """Return the specified tree distance between this and another tree.

        A tree counts as rooted when its root has exactly 2 children
        (``len(tree) == 2``). Both trees must agree on this.

        Parameters
        ----------
        other: PhyloNode
            The other tree to calculate the distance between.
        method: str | None
            The tree distance metric to use.

            Options are:
            "rooted_robinson_foulds": The Robinson-Foulds distance for rooted trees.
            "unrooted_robinson_foulds": The Robinson-Foulds distance for unrooted trees.
            "matching_cluster": The Matching Cluster distance for rooted trees.
            "lin_rajan_moret": The Lin-Rajan-Moret distance for unrooted trees.
            "rrf": An alias for rooted_robinson_foulds.
            "urf": An alias for unrooted_robinson_foulds.
            "mc": An alias for matching_cluster.
            "lrm": An alias for lin_rajan_moret.
            "rf": The unrooted/rooted Robinson-Foulds distance for unrooted/rooted trees.
            "matching": The Lin-Rajan-Moret/Matching Cluster distance for unrooted/rooted trees.

            Default (None) is "matching", i.e. Lin-Rajan-Moret for unrooted
            trees and Matching Cluster for rooted trees.

        Returns
        -------
        int
            the chosen distance between the two trees.

        Raises
        ------
        ValueError
            if one tree is rooted and the other is not

        Notes
        -----
        The Lin-Rajan-Moret distance [2]_ and Matching Cluster distance [1]_
        display superior statistical properties than the Robinson-Foulds
        distance [3]_ on unrooted and rooted trees respectively.

        References
        ----------
        .. [1] Bogdanowicz, D., & Giaro, K. (2013).
           On a matching distance between rooted phylogenetic trees.
           International Journal of Applied Mathematics and Computer Science, 23(3), 669-684.
        .. [2] Lin et al. 2012
           A Metric for Phylogenetic Trees Based on Matching
           IEEE/ACM Transactions on Computational Biology and Bioinformatics
           vol. 9, no. 4, pp. 1014-1022, July-Aug. 2012
        .. [3] Robinson, David F., and Leslie R. Foulds.
           Comparison of phylogenetic trees.
           Mathematical biosciences 53.1-2 (1981): 131-147.
        """
        chosen = "matching" if method is None else method

        self_rooted = len(self) == 2
        other_rooted = len(other) == 2
        if self_rooted != other_rooted:
            msg = "Both trees must be rooted or both trees must be unrooted."
            raise ValueError(msg)

        measure = get_tree_distance_measure(chosen, self_rooted)
        return measure(self, other)
1810

1811
    def lin_rajan_moret(self, tree2: Self) -> int:
        """return the lin-rajan-moret distance between trees

        Parameters
        ----------
        tree2
            the tree to compare against; like self it must be unrooted

        Returns
        -------
        int
            the Lin-Rajan-Moret distance

        Notes
        -----
        This is a distance measure that exhibits superior statistical
        properties compared to Robinson-Foulds. It can only be applied to
        unrooted trees.

        see: Lin et al. 2012
        A Metric for Phylogenetic Trees Based on Matching
        IEEE/ACM Transactions on Computational Biology and Bioinformatics
        vol. 9, no. 4, pp. 1014-1022, July-Aug. 2012
        """
        from cogent3.phylo.tree_distance import lin_rajan_moret

        return lin_rajan_moret(self, tree2)
1831

1832
    def child_parent_map(self) -> dict[str, str]:
        """return dict of {<child name>: <parent name>, ...}

        Covers every node below self, visited in postorder; self is excluded.
        """
        mapping: dict[str, str] = {}
        for node in self.postorder(include_self=False):
            mapping[node.name] = cast("Self", node.parent).name
        return mapping
1838

1839
    def distance(self, other: Self) -> float:
        """Returns branch length between self and other.

        The lengths of all edges on the path between the two nodes are
        summed; edges whose length is None or 0 contribute nothing.

        Raises
        ------
        TreeError
            if other is not in the same tree as self
        """
        if self is other:
            return 0

        # ids of self and all of its ancestors, for O(1) membership tests
        lineage = {id(node) for node in self.ancestors()}
        lineage.add(id(self))

        total = 0.0
        # climb from other until reaching a node on self's lineage
        node: Self | None = other
        while node is not None and id(node) not in lineage:
            if node.length:
                total += node.length
            node = node._parent

        if node is None:
            msg = "The other node is not in the same tree."
            raise TreeError(msg)

        # node is the first shared ancestor -- add self's side of the path
        walker = self
        while walker is not node:
            if walker.length:
                total += walker.length
            walker = cast("Self", walker._parent)
        return total
1867

1868
    def total_descending_branch_length(self) -> float:
        """Returns total descending branch length from self

        Sums the lengths of all nodes below self (self excluded), skipping
        lengths that are None.
        """
        lengths = (node.length for node in self.preorder(include_self=False))
        return sum(length for length in lengths if length is not None)
1873

1874
    def total_length(self) -> float:
        """returns the sum of all branch lengths in the whole tree

        Measured from the root, regardless of which node this is called on.
        """
        return self.get_root().total_descending_branch_length()
1878

1879
    def tips_within_distance(self, distance: float) -> list[Self]:
        """Returns tips within specified distance from self

        Branch lengths of None will be interpreted as 0. The search moves
        outward from self breadth-first, through both parents and children,
        visiting each node at most once. self itself is never returned.

        Parameters
        ----------
        distance
            maximum summed branch length (inclusive) between self and a tip
        """
        from collections import deque

        def get_distance(d1: float, d2: float | None) -> float:
            return d1 if d2 is None else d1 + d2

        # deque gives O(1) pops from the front; list.pop(0) is O(n)
        to_process: deque[tuple[Self, float]] = deque([(self, 0.0)])
        tips_to_save: list[Self] = []

        seen = {id(self)}
        while to_process:
            curr_node, curr_dist = to_process.popleft()

            # found a tip within distance? compare by identity, consistent
            # with the id()-based bookkeeping in `seen`
            if curr_node.is_tip() and curr_node is not self:
                tips_to_save.append(curr_node)
                continue

            # add the parent node if it is within distance
            parent = curr_node.parent
            parent_dist = get_distance(curr_dist, curr_node.length)
            if (
                parent is not None
                and parent_dist <= distance
                and id(parent) not in seen
            ):
                to_process.append((parent, parent_dist))
                seen.add(id(parent))

            # add children if we haven't seen them and if they are in distance
            for child in curr_node.children:
                if id(child) in seen:
                    continue
                seen.add(id(child))

                child_dist = get_distance(curr_dist, child.length)
                if child_dist <= distance:
                    to_process.append((child, child_dist))

        return tips_to_save
1923

1924
    def root_at_midpoint(self) -> Self:
        """return a new tree rooted at midpoint of the two tips farthest apart

        this fn doesn't preserve the internal node naming or structure,
        but does keep tip to tip distances correct.  uses unrooted_deepcopy()

        self is never modified; a new tree is always returned.

        Raises
        ------
        TreeError
            if the maximum tip-to-tip distance is not greater than 0
        """
        dmat = self.tip_to_tip_distances()
        a, b = dmat.max_pair()
        max_dist = dmat[a, b]
        if max_dist <= 0.0:
            msg = f"{max_dist=} must be > 0"
            raise TreeError(msg)

        mid_point = max_dist / 2.0
        path_nodes = self.get_connecting_edges(a, b)
        # same convention as tip_to_tip_distances: a missing length counts as
        # 0 if any node has a length, otherwise every edge counts as 1
        has_length = any(node.length is not None for node in self.preorder())
        default_length = 0.0 if has_length else 1.0

        # walk the a-b path until we reach the edge containing the midpoint;
        # path_nodes is non-empty because a != b when max_dist > 0
        cumsum = 0.0
        for node in path_nodes:
            cumsum += node.length if node.length is not None else default_length
            if cumsum >= mid_point:
                break

        parent: Self = cast("Self", node.parent)
        if parent.is_root() and len(parent.children) == 2:
            # already rooted on the midpoint edge, only the two root edge
            # lengths need adjusting -- do that on a copy so self is untouched
            new_tree = self.deepcopy()
        else:
            new_tree = self.rooted(node.name)
        _adjust_lengths_from_root(tip_name=a, mid_point=mid_point, tree=new_tree)
        return new_tree
1959

1960
    def get_max_tip_tip_distance(
        self,
    ) -> tuple[float, tuple[str, str], Self]:
        """Returns the max tip-to-tip distance between any pair of tips

        Returns
        -------
        dist, tip_names, internal_node
            the distance as a float, the names of the two tips, and the node
            returned by ``get_connecting_node`` for that pair
        """
        matrix = self.tip_to_tip_distances()
        first, second = matrix.max_pair()
        connector = self.get_connecting_node(first, second)
        return float(matrix[first, second]), (first, second), connector
1973

1974
    def max_tip_tip_distance(self) -> tuple[float, tuple[str, str]]:
        """returns the max distance between any pair of tips

        Also returns the tip names that it is between as a tuple"""
        result = self.get_max_tip_tip_distance()
        return result[0], result[1]
1980

1981
    @staticmethod
    def parse_token(token: str | None) -> tuple[str | None, float | None]:
        """Split a Newick label token into (name, support).

        Thin wrapper around ``split_name_and_support``.
        """
        return split_name_and_support(token)
1985

1986
    def tip_to_root_distances(
        self,
        names: list[str] | None = None,
        default_length: float = 1,
        *,
        node_length: bool = False,
    ) -> dict[str, float]:
        """returns the cumulative sum of lengths from each tip to the root

        Parameters
        ----------
        names
            list of tip names to calculate distances for, defaults to all
        default_length
            value to use for edges that have no length value
        node_length
            if True, edge lengths are ignored and every edge counts as
            default_length (i.e. the number of edges to the root, scaled by
            default_length)

        Raises
        ------
        TreeError
            if no tips match names, or the tree has no tips
        """
        # materialise so the emptiness check below works even if tips()
        # returns an iterator
        tips = list(self.tips())
        if names is not None:
            wanted = set(names)
            tips = [t for t in tips if t.name in wanted]

        if not tips:
            msg = f"No tips matching in {names!r}"
            raise TreeError(msg)

        dists: dict[str, float] = {}
        for tip in tips:
            node = tip
            cum_sum = 0.0
            while node.parent is not None:
                if node_length or node.length is None:
                    cum_sum += default_length
                else:
                    cum_sum += node.length

                node = node.parent
            dists[tip.name] = cum_sum
        return dists
2023

2024
    def renamed_nodes(self, name_map: dict[str, str]) -> Self:
        """returns a copy of the tree with nodes renamed according to name_map

        Nodes whose names are not keys of name_map keep their name; the
        original tree is not modified.

        Parameters
        ----------
        name_map
            dict of {old_name: new_name, ...}
        """
        copied = self.deepcopy()
        lookup = name_map.get
        for node in copied.preorder():
            current = node.name
            node.name = lookup(current, current)
        return copied
2036

2037

2038
# type variable restricted to PhyloNode or its subclasses
T = TypeVar("T", bound=PhyloNode)
6✔
2039

2040

2041
def _adjust_lengths_from_root(
6✔
2042
    *, tip_name: str, mid_point: float, tree: PhyloNode
2043
) -> None:
2044
    if len(tree.children) != 2:
6✔
2045
        msg = "root node must have 2 children"
×
2046
        raise TreeError(msg)
×
2047

2048
    to_tip, other = tree.children
6✔
2049
    if tip_name not in to_tip.get_tip_names():
6✔
2050
        to_tip, other = other, to_tip
6✔
2051

2052
    a_to_root = tree.tip_to_root_distances(names=[tip_name])
6✔
2053
    delta = a_to_root[tip_name] - mid_point
6✔
2054

2055
    if to_tip.length is not None:
6✔
2056
        to_tip.length -= delta
6✔
2057
    if other.length is not None:
6✔
2058
        other.length += delta
6✔
2059

2060

2061
def split_name_and_support(name_field: str | None) -> tuple[str | None, float | None]:
    """Separate a Newick internal-node label into its name and support parts.

    Labels may hold a name, a support value, or both, like 'edge.98/100'.

    Returns
    -------
    (None, None) for an empty or missing label, (None, value) for a purely
    numeric label such as '24', (name, value) for 'name/value' and
    (name, None) for a label without '/'.

    Raises
    ------
    ValueError
        if the part after '/' is not numeric, or the label has several '/'
    """
    if not name_field:
        return None, None

    # a label that is entirely numeric, e.g. "24", is taken as support only
    try:
        return None, float(name_field)
    except ValueError:
        pass

    name, *rest = name_field.split("/")
    if not rest:
        return name, None

    if len(rest) > 1:
        # more than one '/' in the label
        msg = f"Support value at node: {name!r} should be int or float not {'/'.join(rest)!r}."
        raise ValueError(
            msg,
        )

    try:
        support_value = float(rest[0])
    except ValueError as err:
        msg = f"Support value at node: {name!r} should be int or float not {rest[0]!r}."
        raise ValueError(
            msg,
        ) from err
    return name, support_value
2094

2095

2096
class TreeBuilder:
    """Creates tree nodes for the newick parser and tree-to-tree transforms.

    Some tree code which isn't needed once the tree is finished; it mostly
    exists to give edges unique names. Children must be created before their
    parents.

    Parameters
    ----------
    constructor
        the node class to instantiate, PhyloNode by default
    """

    def __init__(self, constructor: type[PhyloNode] = PhyloNode) -> None:
        # base name -> last numeric suffix issued; "edge" starts at -1 so the
        # first unnamed edge becomes "edge.0"
        self._used_names = {"edge": -1}
        self.PhyloNodeClass = constructor

    def _unique_name(self, name: str | None) -> str:
        """Return name, suffixed if necessary to be unique in this builder.

        Unnamed edges become edge.0, edge.1 edge.2 ...
        Other duplicates go mouse mouse.2 mouse.3 ...
        """
        if not name:
            name = "edge"
        if name in self._used_names:
            self._used_names[name] += 1
            name += f".{self._used_names[name]!s}"
            # in case of names like 'edge.1.1'
            name = self._unique_name(name)
        else:
            self._used_names[name] = 1
        return name

    def _params_for_edge(self, edge: PhyloNode) -> dict[str, Any]:
        # default is just to keep it: the edge's own params dict (not a copy)
        return edge.params

    def edge_from_edge(
        self,
        edge: PhyloNode | None,
        children: PySeq[PhyloNode],
        params: dict[str, Any] | None = None,
    ) -> PhyloNode:
        """Callback for tree-to-tree transforms like get_sub_tree

        Parameters
        ----------
        edge
            node whose name, length, support and (unless params is given)
            params are reused; None creates a node named "root"
        children
            already-built child nodes
        params
            overrides the params taken from edge

        Raises
        ------
        ValueError
            if non-empty params are given with edge=None
        """
        if not isinstance(children, list):
            children = list(children)
        if edge is None:
            if params:
                msg = "No params allowed when edge is None."
                raise ValueError(msg)
            return self.create_edge(
                children,
                "root",
                {},
                None,
                None,
                name_loaded=False,
            )
        if params is None:
            params = self._params_for_edge(edge)
        return self.create_edge(
            children,
            edge.name,
            params,
            edge.length,
            edge.support,
            name_loaded=edge.name_loaded,
        )

    def create_edge(
        self,
        children: PySeq[PhyloNode] | None,
        name: str | None,
        params: dict[str, Any],
        length: float | None,
        support: float | None,
        name_loaded: bool = True,
    ) -> PhyloNode:
        """Callback for newick parser

        For internal nodes (non-empty children) the name token is split into
        name and support via ``PhyloNodeClass.parse_token``; tip names are
        used verbatim. The final name is made unique via ``_unique_name``.

        Raises
        ------
        ValueError
            if support is given both as an argument and inside the name
            token, with different values
        """
        if children is None:
            children = []
        # split name and support for internal nodes only. Test truthiness
        # rather than `!= []` so an empty tuple is also treated as a tip.
        elif children:
            name, new_support = self.PhyloNodeClass.parse_token(name)

            if new_support is not None:
                if support is not None and new_support != support:
                    msg = f"Got conflicting values for support. In name token '{name}': {new_support}. In constructor {support}."
                    raise ValueError(msg)
                support = new_support

        return self.PhyloNodeClass(
            name=self._unique_name(name),
            children=list(children),
            name_loaded=name_loaded and (name is not None),
            params=params,
            length=length,
            support=support,
        )
2185

2186

2187
def make_tree(
    treestring: str | None = None,
    tip_names: list[str] | None = None,
    format_name: str | None = None,
    underscore_unmunge: bool = False,
    source: str | pathlib.Path | None = None,
) -> PhyloNode:
    """Initialises a tree.

    Parameters
    ----------
    treestring
        a tree string, parsed with the newick parser
    tip_names
        a list of tip names, returns a "star" topology tree. Takes
        precedence over treestring.
    format_name
        format of treestring; defaults to "xml" when treestring starts
        with "<", otherwise newick
    underscore_unmunge
        replace underscores with spaces in all names read, i.e. "sp_name"
        becomes "sp name"
    source
        path to file tree came from, string value assigned to tree.source

    Notes
    -----
    Underscore unmunging is turned off by default, although it is part
    of the Newick format.

    Returns
    -------
    PhyloNode

    Raises
    ------
    ValueError
        if neither treestring nor tip_names is provided
    """
    source_str = str(source) if source else None

    if tip_names:
        # star topology: every tip is a direct child of the root
        create = TreeBuilder().create_edge
        leaves = [create([], str(name), {}, None, None) for name in tip_names]
        star = create(leaves, "root", {}, None, None)
        star.source = source_str
        return star

    if not treestring:
        msg = "Must provide either treestring or tip_names."
        raise ValueError(msg)

    if format_name is None and treestring.startswith("<"):
        # NOTE(review): format_name is not consulted below -- every string,
        # XML included, goes to the newick parser. Confirm whether XML input
        # is still meant to be supported.
        format_name = "xml"

    # FIXME: More general strategy for underscore_unmunge
    tree = newick_parse_string(
        treestring, TreeBuilder().create_edge, underscore_unmunge=underscore_unmunge
    )
    if not tree.name_loaded:
        tree.name = "root"

    tree.source = source_str
    return tree
2248

2249

2250
def load_tree(
    filename: str | pathlib.Path,
    format_name: str | None = None,
    underscore_unmunge: bool = False,
) -> PhyloNode:
    """Constructor for tree.

    Parameters
    ----------
    filename
        path to a file containing a tree.
    format_name
        overrides the format inferred from the file name suffix. "json"
        loads a cogent3 json tree; anything else is read as text and passed
        to make_tree.
    underscore_unmunge
        replace underscores with spaces in all names read, i.e. "sp_name"
        becomes "sp name".

    Notes
    -----
    Underscore unmunging is turned off by default, although it is part
    of the Newick format.

    filename is assigned to root node tree.source attribute.

    Returns
    -------
    PhyloNode
    """
    suffix_format, _ = get_format_suffixes(filename)
    fmt = format_name or suffix_format

    if fmt != "json":
        with open_(filename) as handle:
            text = handle.read()
        return make_tree(
            text,
            format_name=fmt,
            underscore_unmunge=underscore_unmunge,
            source=filename,
        )

    tree = load_from_json(filename, (PhyloNode,))
    tree.source = str(filename)
    return tree
2296

2297

2298
@register_deserialiser("cogent3.core.tree")
def deserialise_tree(
    data: dict[str, Any],
) -> PhyloNode:
    """returns a cogent3 PhyloNode instance

    Parameters
    ----------
    data
        serialised tree; reads the "newick" and "edge_attributes" keys, and
        "length_and_support" when present. Older json lacking
        "length_and_support" keeps length/support inside edge_attributes
        and triggers a warning. data itself is not modified.
    """
    # we load tree using make_tree, then populate edge attributes
    edge_attr = cast("dict[str, dict[str, Any]]", data["edge_attributes"])
    length_and_support: dict[str, dict[str, float | None]] | None = cast(
        "dict[str, dict[str, float | None]] | None", data.get("length_and_support")
    )

    if length_and_support is None:
        warnings.warn(
            "Outdated tree json. Please update by regenerating the json on the loaded tree.",
            stacklevel=3,
        )

    tree = make_tree(treestring=cast("str", data["newick"]))
    for edge in tree.preorder():
        # copy so popping the legacy keys below does not mutate the caller's
        # data (which would lose length/support on a second deserialisation)
        params = dict(edge_attr.get(edge.name, {}))
        if length_and_support is None:
            edge.length = params.pop("length", None)
            edge.support = params.pop("support", None)
        else:
            edge.length = length_and_support[edge.name]["length"]
            edge.support = length_and_support[edge.name]["support"]
        edge.params.update(params)
    return tree
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc