
js51 / SplitP / build 15843786339

24 Jun 2025 07:10AM UTC · coverage: 46.006% (+0.4%) from 45.65%
Commit "Update requirements.txt" (push, via github / web-flow)
455 of 989 relevant lines covered (46.01%) · 0.92 hits per line

Source File: /splitp/phylogenetics.py (16.77% covered)
import splitp as sp
import numpy as np
from itertools import combinations
import networkx as nx
from networkx import dfs_postorder_nodes, bfs_successors
from splitp import splits
from splitp.matrix import is_sparse, frobenius_norm
import scipy.linalg
import scipy.sparse.linalg
import splitp.constants as constants
from math import sqrt


def parsimony_score(self, pattern):
    """Calculate a parsimony score for a site pattern or split

    Args:
        pattern: A string representing a site pattern or split

    Returns:
        A parsimony score for the given split or site pattern
    """
    graph = self.nx_graph.copy()
    nodes = graph.nodes
    if "|" in pattern:
        # Convert split notation, e.g. "01|23", into a binary site pattern
        pattern2 = [0 for x in range(len(pattern) - 1)]
        i = 0
        while pattern[i] != "|":
            pattern2[int(pattern[i])] = 1
            i += 1
        pattern = "".join(str(i) for i in pattern2)

    taxa = [t for t in self.taxa]
    for i, t in enumerate(taxa):
        nodes[t]["pars"] = set(pattern[i])
    score = 0
    for n in nodes:
        if not self.is_leaf(n) and "pars" not in nodes[n]:
            score += self.__parsimony(n, nodes=nodes)
    return score

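# Usage sketch for parsimony_score (illustrative only; assumes these
# module-level functions are bound as methods of the splitp tree class,
# e.g. sp.NXTree, which this file does not show):
#
#   tree = sp.NXTree("((0,1),(2,3));")
#   tree.parsimony_score("0011")   # pattern needing one change -> 1
#   tree.parsimony_score("01|23")  # the same bipartition in split notation
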
def __parsimony(self, n, nodes=None):
    """Recursive step in Fitch's algorithm"""
    score = 0
    children = self.get_descendants(n)
    for c in children:
        if "pars" not in nodes[c]:
            score += self.__parsimony(c, nodes)

    nodes[n]["pars"] = nodes[children[0]]["pars"]
    for c in children:
        nodes[n]["pars"] = nodes[n]["pars"].intersection(nodes[c]["pars"])
    if nodes[n]["pars"] == set():
        # Empty intersection: take the union of the children's state sets
        # and count one state change (the Fitch union step)
        nodes[n]["pars"] = nodes[children[0]]["pars"]
        for c in children:
            nodes[n]["pars"] = (nodes[n]["pars"]).union(nodes[c]["pars"])
        return score + 1
    return score

def hartigan_algorithm(self, pattern):
    """Minimum number of state changes for a site pattern, via Hartigan's algorithm."""
    score = 0
    graph = self.nx_graph.copy()
    nodes = graph.nodes
    taxa = [t for t in self.taxa]
    for i, t in enumerate(taxa):
        nodes[t]["S1"] = set(pattern[i])
    # Bottom-up pass: compute candidate state sets S1 (and runner-up sets S2)
    postorder_nodes = list(
        dfs_postorder_nodes(graph, source=self.get_root(return_index=False))
    )
    for n in postorder_nodes:
        if not self.is_leaf(n):
            children = self.get_descendants(n)
            k = {}
            for state in self.state_space:
                k[state] = len(
                    set(child for child in children if state in nodes[child]["S1"])
                )
            K = max(k.values())
            nodes[n]["S1"] = {state for state in self.state_space if k[state] == K}
            nodes[n]["S2"] = {state for state in self.state_space if k[state] == K - 1}
    # Top-down pass: assign states and count the changes
    top_to_bottom_nodes = [
        x[0] for x in bfs_successors(graph, source=self.get_root(return_index=False))
    ] + taxa
    for n in top_to_bottom_nodes:
        if n == self.get_root(return_index=False):
            nodes[n]["hart_state"] = list(nodes[n]["S1"])[0]
        else:
            parent = nodes[list(graph.predecessors(n))[0]]
            if parent["hart_state"] not in nodes[n]["S1"]:
                nodes[n]["hart_state"] = list(nodes[n]["S1"])[0]
                score += 1
            else:
                nodes[n]["hart_state"] = parent["hart_state"]
    return score

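# Usage sketch for hartigan_algorithm (same binding assumption as above, and
# assuming the tree's state_space covers the pattern characters). On the
# quartet ((0,1),(2,3)); the pattern "AACC" needs a single state change:
#
#   tree = sp.NXTree("((0,1),(2,3));", taxa_ordering="sorted")
#   tree.hartigan_algorithm("AACC")  # -> 1
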
def erickson_SVD(alignment, taxa=None, method=sp.Method.flattening, show_work=False):
    all_scores = {}
    subflattening_data = {}

    def _erickstep(all_taxa, alignment):
        scores = {}
        for pair in combinations(all_taxa, 2):
            flat_pair = tuple(
                sorted(
                    element
                    for tup in pair
                    for element in (tup if not isinstance(tup, str) else (tup,))
                )
            )
            other = tuple(
                sorted(
                    element
                    for tup in all_taxa
                    for element in (tup if not isinstance(tup, str) else (tup,))
                    if element not in flat_pair
                )
            )
            split = (flat_pair, other)
            try:
                score = all_scores[split]
            except KeyError:
                score = np.inf
                if method == sp.Method.flattening:
                    flattening = sp.flattening(split, alignment, sp.FlatFormat.reduced)
                    score = sp.split_score(flattening)
                elif method == sp.Method.subflattening:
                    subflattening = sp.subflattening(
                        split, alignment, subflattening_data
                    )
                    score = sp.split_score(subflattening)
                elif method == sp.Method.mutual_information:
                    flattening = sp.flattening(split, alignment, sp.FlatFormat.reduced)
                    score = sp.phylogenetics.flattening_rank_1_approximation_divergence(
                        flattening
                    )
                all_scores[split] = score
            scores[pair] = (pair, split, score)
        if show_work:
            print(f"Scores: {scores}")
        best_pair, best_split, best_score = min(scores.values(), key=lambda x: x[2])
        return best_pair, best_split, best_score

    num_taxa = len(list(alignment.keys())[0])  # Length of first pattern
    if taxa is None:
        taxa = [
            str(np.base_repr(i, base=max(i + 1, 2))) if num_taxa <= 36 else f"t{str(i)}"
            for i in range(num_taxa)
        ]

    true_splits = []
    while len(true_splits) < num_taxa - 2:
        best_pair, best_split, best_score = _erickstep(taxa, alignment)
        true_splits.append(tuple(sorted(best_split)))
        # Replace the joined pair with a single composite taxon
        taxa = tuple(
            [
                element
                for element in taxa
                if (
                    element not in best_split[0]
                    and not set(element).issubset(best_split[0])
                )
            ]
            + [best_split[0]]
        )
    return true_splits

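# Usage sketch for erickson_SVD (illustrative): `alignment` is a dict mapping
# site patterns to counts or frequencies, as consumed by the functions below.
# The routine greedily joins the pair of taxa whose induced split scores best
# (lowest), until num_taxa - 2 splits have been collected:
#
#   alignment = {"AAAA": 50, "AACC": 40, "ACAC": 10}
#   erickson_SVD(alignment, method=sp.Method.flattening)
#   # -> e.g. [(('0', '1'), ('2', '3')), ...]
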
def newick_string_from_splits(splits):
    def _consolidate(tup, smaller_halves):
        if len(tup) == 2:
            return tup
        if len(tup) == 1:
            return tup[0]
        if isinstance(tup, str):
            return tup
        for smaller_half in smaller_halves:
            if set(smaller_half).issubset(tup) and len(smaller_half) < len(tup):
                # the consolidation is made up of the smaller half and what is left over
                left_over = tuple(set(tup).difference(set(smaller_half)))
                try:
                    smaller_halves.remove(smaller_half)
                    smaller_halves.remove(left_over)
                except ValueError:
                    pass
                return tuple(
                    (
                        _consolidate(left_over, smaller_halves),
                        _consolidate(smaller_half, smaller_halves),
                    )
                )

    splits = sorted(splits, key=lambda x: min(len(x[0]), len(x[1])), reverse=True)
    if len(splits) == 1:
        return str(splits[0]).replace("'", "").replace(" ", "") + ";"
    if len(splits) == 0:
        return ";"
    splits = iter(splits)
    first_split = next(splits)
    smaller_halves = [min(split, key=len) for split in splits]
    consolidated_split = tuple(
        _consolidate(half, smaller_halves) for half in first_split
    )
    return str(consolidated_split).replace("'", "").replace(" ", "") + ";"

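# Example for newick_string_from_splits: a single split is rendered directly,
#
#   newick_string_from_splits([(("0", "1"), ("2", "3"))])  # -> "((0,1),(2,3));"
#
# while a compatible set of splits on more taxa is consolidated recursively,
# smallest halves first.
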
def tree_from_splits(splits):
    return sp.NXTree(newick_string_from_splits(splits))

def split_tree_parsimony(alignment, splits=None):
    if isinstance(alignment, dict):
        alignment_dict = alignment
    else:
        alignment_dict = {}
        for table_pattern, value in alignment.itertuples(index=False, name=None):
            alignment_dict[table_pattern] = value
    num_taxa = len(list(alignment_dict.keys())[0])  # Length of first pattern
    # Note: the `splits` parameter shadows the `splitp.splits` module,
    # so the module is reached through `sp.splits` here
    all_splits = list(sp.splits.all_splits(num_taxa)) if splits is None else splits
    scores = {split: 0 for split in all_splits}
    for split in all_splits:
        newick_string = []
        for part in split.split("|"):
            newick_string.append(f'({",".join(c for c in part)})')
        newick_string = f"({newick_string[0]},{newick_string[1]});"
        split_tree = sp.NXTree(newick_string, taxa_ordering="sorted")
        for pattern, value in alignment_dict.items():
            scores[split] += value * (
                split_tree.hartigan_algorithm(pattern) / (num_taxa - 1)
            )
    return scores

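# Usage sketch for split_tree_parsimony: splits use the "digits|digits" string
# format, and each score is the pattern-weighted Hartigan score of the split's
# two-clade tree, normalised by (num_taxa - 1):
#
#   alignment = {"AACC": 3, "AAAA": 7}
#   split_tree_parsimony(alignment, splits=["01|23", "02|13"])
#   # -> {"01|23": 1.0, "02|13": 2.0}
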
def JC_corrected_distance_matrix(alignment):
    """Pairwise Jukes-Cantor corrected distances from pattern frequencies."""
    num_taxa = len(list(alignment.keys())[0])  # Length of first pattern
    distance_matrix = [[0 for _ in range(num_taxa)] for _ in range(num_taxa)]
    # Accumulate the proportion of sites at which each pair of taxa differs
    for i in range(num_taxa):
        for j in range(i + 1, num_taxa):
            for pattern, value in alignment.items():
                if pattern[i] != pattern[j]:
                    distance_matrix[i][j] += value
                    distance_matrix[j][i] += value
    # Apply the Jukes-Cantor correction d = -(3/4) ln(1 - (4/3) p)
    for i in range(num_taxa):
        for j in range(i + 1, num_taxa):
            distance_matrix[i][j] = (
                -3.0 / 4.0 * np.log(1 - 4.0 / 3.0 * distance_matrix[i][j])
            )
            distance_matrix[j][i] = distance_matrix[i][j]
    return distance_matrix

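# Worked example of the correction d = -(3/4) ln(1 - (4/3) p): an observed
# disagreement proportion p = 0.15 gives d = -0.75 * ln(0.8) ≈ 0.167
# substitutions per site. (This presumes the alignment values are relative
# pattern frequencies, so the accumulated sums are proportions, not counts.)
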
def euclidean_split_distance(alignment, splits):
    print("assuming sorted taxa")
    states = ("A", "G", "C", "T")
    alignment_dict = {}
    for table_pattern, value in alignment.itertuples(index=False, name=None):
        alignment_dict[table_pattern] = value
    num_taxa = len(list(alignment_dict.keys())[0])  # Length of first pattern
    # As above, `sp.splits` avoids the parameter shadowing the module
    all_splits = list(sp.splits.all_splits(num_taxa)) if splits is None else splits
    scores = {split: 0 for split in all_splits}
    for split in all_splits:
        split_list = split.split("|")
        for pattern, value in alignment_dict.items():
            part_a = "".join(pattern[int(s, base=num_taxa + 1)] for s in split_list[0])
            part_b = "".join(pattern[int(s, base=num_taxa + 1)] for s in split_list[1])
            # Normalised state-frequency vectors for each side of the split
            vec_a = np.array([part_a.count(state) for state in states])
            vec_b = np.array([part_b.count(state) for state in states])
            vec_a = vec_a / np.linalg.norm(vec_a)
            vec_b = vec_b / np.linalg.norm(vec_b)
            scores[split] += value * (2 - np.linalg.norm(vec_a - vec_b)) / 2
    return scores

def __dense_split_score(matrix, return_singular_values=False, force_frob_norm=False):
    singular_values = list(
        scipy.linalg.svd(
            np.array(matrix), full_matrices=False, check_finite=False, compute_uv=False
        )
    )
    if force_frob_norm:
        return (
            1
            - (sum(val**2 for val in singular_values[0:4]))
            / (frobenius_norm(matrix) ** 2)
        ) ** (1 / 2)
    else:
        min_shape = min(matrix.shape)
        return (
            1
            - (
                sum(val**2 for val in singular_values[0:4])
                / sum(val**2 for val in singular_values[0:min_shape])
            )
        ) ** (1 / 2)

def __sparse_split_score(
    matrix, return_singular_values=False, data_table_for_frob_norm=None
):
    largest_four_singular_values = scipy.sparse.linalg.svds(
        matrix, 4, return_singular_vectors=False
    )
    squared_singular_values = [sigma**2 for sigma in largest_four_singular_values]
    norm = frobenius_norm(matrix, data_table=data_table_for_frob_norm)
    # Guard against small negative values caused by floating-point error
    operand = 1 - (sum(squared_singular_values) / (norm**2))
    return sqrt(operand if operand > 0 else 0)

def split_score(
    matrix,
    return_singular_values=False,
    force_frob_norm_on_dense=False,
    data_table_for_frob_norm=None,
):
    if is_sparse(matrix):
        return __sparse_split_score(
            matrix, return_singular_values, data_table_for_frob_norm
        )
    else:
        return __dense_split_score(
            matrix, return_singular_values, force_frob_norm_on_dense
        )

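# Sanity-check sketch for split_score: a matrix of rank <= 4 scores 0, since
# its four largest squared singular values already account for the full
# squared Frobenius norm:
#
#   A = np.random.rand(16, 4)
#   split_score(A @ A.T)  # -> ~0.0, up to floating-point error
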
def flattening_rank_1_approximation(
    flattening, return_vectors=False, dont_compute_matrix=False
):
    # sum(flattening) iterates over rows, giving column sums;
    # sum(flattening.T) gives row sums
    r = np.array([sum(flattening)])
    c = np.array([sum(flattening.T)])
    approximation = None if dont_compute_matrix else r.T @ c
    if return_vectors:
        return approximation, r.tolist()[0], c.tolist()[0]
    else:
        return approximation

def flattening_rank_k_approximation(split, alignment):
    taxa = sorted(set(split[0]) | set(split[1]))
    sums_of_rows = [
        sum(
            sp.constructions.sparse_flattening_with_banned_patterns(
                split, alignment, taxa, ban_row_patterns=char
            )
        )
        for char in constants.DNA_state_space
    ]
    sums_of_cols = [
        sum(
            sp.constructions.sparse_flattening_with_banned_patterns(
                split, alignment, taxa, ban_col_patterns=char
            ).T
        )
        for char in constants.DNA_state_space
    ]
    return sum(A.T * B for A, B in zip(sums_of_rows, sums_of_cols))

def flattening_rank_1_approximation_divergence(flattening):
    """KL divergence between a flattening and its rank-1 (independence) approximation."""
    _, r, c = flattening_rank_1_approximation(
        flattening, return_vectors=True, dont_compute_matrix=True
    )
    total = 0
    for x in range(len(c)):
        for y in range(len(r)):
            if flattening[x, y] != 0:
                total += flattening[x, y] * np.log(flattening[x, y] / (r[y] * c[x]))
    return total

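# Example: a flattening that already factorises into its row and column
# marginals (i.e. equals its rank-1 approximation) has divergence 0:
#
#   F = np.full((4, 4), 1 / 16)
#   flattening_rank_1_approximation_divergence(F)  # -> 0.0
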
def star_tree(num_leaves):
    root_index = -1
    G = nx.DiGraph()
    G.add_node(root_index)
    # add the leaves as nodes and edges to the central node
    for i in range(0, num_leaves):
        G.add_node(i)
        G.add_edge(root_index, i)
    return G

def join_nodes(T, i, j, new_node, root_index):
    # Add new node
    T.add_node(new_node)
    T.add_edge(root_index, new_node)
    # Join nodes
    T.add_edge(new_node, i)
    T.add_edge(new_node, j)
    # Remove old edges
    T.remove_edge(root_index, i)
    T.remove_edge(root_index, j)

def neighbour_joining(distance_matrix, labels=None, return_newick=False):
    # Initialise
    D = distance_matrix.copy().astype(float)
    n = D.shape[0]
    new_node = n
    root_index = -1
    ignore = {-1}
    T = star_tree(n)
    num_leaves = n
    if labels is not None:
        # Add a label to each node
        for i in range(n):
            T.nodes[i]["label"] = labels[i]

    # NJ Algorithm
    while num_leaves > 2:
        # Instantiate Q matrix
        Q = np.full((n, n), np.inf)
        for i in range(n):
            for j in range(n):
                if i != j and i not in ignore and j not in ignore:
                    Q[i, j] = (num_leaves - 2) * D[i, j] - sum(D[i, :]) - sum(D[:, j])

        # Get the smallest value in Q, collect all occurrences of it as
        # (i, j) pairs, and choose one at random
        min_value = np.min(Q)
        min_indices = np.where(Q == min_value)
        min_pairs = list(zip(*min_indices))
        i, j = min_pairs[np.random.randint(len(min_pairs))]
        # (Deterministic alternative: i, j = np.unravel_index(Q.argmin(), Q.shape))
        ignore.add(i)
        ignore.add(j)

        # Join new nodes
        join_nodes(T, i, j, new_node, root_index)

        # Estimate branch lengths
        dist_new_to_i = (1 / 2) * D[i, j] + (1 / (2 * (num_leaves - 2))) * (
            sum(D[i, :]) - sum(D[j, :])
        )
        T.edges[new_node, i]["weight"] = dist_new_to_i
        T.edges[new_node, j]["weight"] = D[i, j] - dist_new_to_i

        # Append new row and column to distance matrix
        D = np.append(D, np.zeros((1, D.shape[0])), axis=0)
        D = np.append(D, np.zeros((D.shape[0], 1)), axis=1)
        n = D.shape[0]

        # Compute distance from other leaves to new node
        for k in range(n - 1):
            if k not in ignore:
                D[k, new_node] = (1 / 2) * (D[i, k] + D[j, k] - D[i, j])
                D[new_node, k] = D[k, new_node]

        # 'Delete' i and j by zeroing their rows and columns
        D[i, :] = np.zeros(n)
        D[:, i] = np.zeros(n)
        D[j, :] = np.zeros(n)
        D[:, j] = np.zeros(n)

        new_node += 1
        num_leaves -= 1

    # Join the last two nodes
    i = max(set(range(n)) - ignore)
    j = max(set(range(n)) - ignore - {i})
    # Remove the root node and connected edges
    T.remove_node(root_index)
    # Join nodes
    T.add_edge(i, j)
    # Add branch length
    T.edges[i, j]["weight"] = D[i, j]

    return T

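# Usage sketch for neighbour_joining on an additive four-taxon distance
# matrix; NJ recovers the quartet ab|cd with leaf edges a:2, b:3, c:4, d:4
# and internal edge 3:
#
#   D = np.array([[0, 5, 9, 9],
#                 [5, 0, 10, 10],
#                 [9, 10, 0, 8],
#                 [9, 10, 8, 0]], dtype=float)
#   T = neighbour_joining(D, labels=["a", "b", "c", "d"])  # networkx tree
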
def distance_matrix(networkx_tree):
    """Distance matrix of a tree.

    Args:
        networkx_tree (networkx.DiGraph): A tree.

    Returns:
        numpy.ndarray: A distance matrix.
    """
    # Get all the leaves
    leaf_nodes = [
        node for node in networkx_tree.nodes if networkx_tree.out_degree(node) == 0
    ]
    # Fill the distance matrix with weighted path lengths between leaves
    distance_matrix = np.zeros((len(leaf_nodes), len(leaf_nodes)))
    for i in range(len(leaf_nodes)):
        for j in range(i + 1, len(leaf_nodes)):
            distance = nx.shortest_path_length(
                networkx_tree.to_undirected(),
                leaf_nodes[i],
                leaf_nodes[j],
                weight="weight",
            )
            distance_matrix[i, j] = distance
            distance_matrix[j, i] = distance
    return distance_matrix

def midpoint_rooting(networkx_tree, weight_label="weight"):
    """Midpoint rooting of a tree.

    Args:
        networkx_tree (networkx.DiGraph): A tree.

    Returns:
        networkx.DiGraph: A rooted tree.
    """
    # Get all the leaves
    leaf_nodes = [
        node for node in networkx_tree.nodes if networkx_tree.out_degree(node) == 0
    ]
    # Get the distance matrix and the pair of most distant leaves
    D = distance_matrix(networkx_tree)
    max_dist = np.max(D)
    i, j = np.unravel_index(np.argmax(D, axis=None), D.shape)
    # Get the undirected version of the tree and the path between those leaves
    tree_undirected = networkx_tree.to_undirected()
    path = nx.shortest_path(tree_undirected, leaf_nodes[i], leaf_nodes[j])
    midpoint_dist = max_dist / 2
    # Travel along the path until the midpoint is reached, then insert a new root node
    current_dist = 0
    prev_dist = 0
    for k in range(len(path) - 1):
        prev_dist = current_dist
        current_dist += tree_undirected[path[k]][path[k + 1]][weight_label]
        if current_dist >= midpoint_dist:
            # Add a new node
            new_node = -1
            networkx_tree.add_node(new_node)
            # Add the edges
            networkx_tree.add_edge(new_node, path[k])
            networkx_tree.add_edge(new_node, path[k + 1])
            # Remove the old edges
            if (path[k], path[k + 1]) in networkx_tree.edges:
                networkx_tree.remove_edge(path[k], path[k + 1])
            elif (path[k + 1], path[k]) in networkx_tree.edges:
                networkx_tree.remove_edge(path[k + 1], path[k])
            else:
                raise ValueError("Edge not found. Is tree already rooted?")
            # Add the branch lengths: the distance from path[k] to the
            # midpoint, and from the midpoint to path[k + 1]
            networkx_tree.edges[new_node, path[k]][weight_label] = (
                midpoint_dist - prev_dist
            )
            networkx_tree.edges[new_node, path[k + 1]][weight_label] = (
                current_dist - midpoint_dist
            )
            break