• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

js51 / SplitP / 7987035156

21 Feb 2024 09:47AM UTC coverage: 46.646% (+0.4%) from 46.211%
7987035156

push

github

js51
add taxa ordering and an Alignment class for keeping track of it

21 of 45 new or added lines in 7 files covered. (46.67%)

6 existing lines in 3 files now uncovered.

445 of 954 relevant lines covered (46.65%)

1.4 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

84.29
/splitp/constructions.py
1
import numpy as np
3✔
2
from splitp.enums import FlatFormat
3✔
3
from scipy.sparse import dok_matrix, coo_matrix
3✔
4
import splitp.constants as constants
3✔
5

6

7
def flattening(split, pattern_probabilities, flattening_format=FlatFormat.sparse):
    """
    Compute the flattening of a split given a pattern probability dictionary.

    Args:
        split (str or list): The split to compute the flattening of. A string
            is interpreted as the two halves joined by "|", e.g. "01|23".
        pattern_probabilities (dict): A dictionary of pattern probabilities.
            May optionally carry a ``taxa`` attribute giving the taxa
            ordering; otherwise a sorted ordering is recovered from the split.
        flattening_format (FlatFormat): The format to return the flattening in.

    Returns:
        The flattening of the split in the specified format.

    Raises:
        ValueError: If ``flattening_format`` is not a recognised format.
    """
    if isinstance(split, str):
        split = split.split("|")
    try:
        # An Alignment-style object supplies the canonical taxa ordering.
        taxa = pattern_probabilities.taxa
    except AttributeError:
        # Fall back to a deterministic ordering derived from the split itself.
        taxa = sorted(set.union(*map(set, split)))
    if flattening_format is FlatFormat.sparse:
        return __sparse_flattening(split, pattern_probabilities, taxa)
    if flattening_format is FlatFormat.reduced:
        return __reduced_flattening(split, pattern_probabilities, taxa)
    # Previously an unrecognised format fell through and silently returned
    # None; fail loudly instead so callers notice the misuse.
    raise ValueError(f"Unsupported flattening format: {flattening_format!r}")
29

30

31
def __reduced_flattening(split, pattern_probabilities, taxa):
    """Build the reduced (dense) flattening of ``split``.

    Only rows and columns that actually receive at least one observed
    pattern probability appear in the result; rows and columns are ordered
    by ascending pattern index.
    """
    if isinstance(split, str):
        split = split.split("|")
    taxon_position = {taxon: i for i, taxon in enumerate(taxa)}
    row_entries = {}
    seen_columns = set()
    for pattern, probability in pattern_probabilities.items():
        row_key = __index_of("".join(str(pattern[taxon_position[t]]) for t in split[0]))
        col_key = __index_of("".join(str(pattern[taxon_position[t]]) for t in split[1]))
        seen_columns.add(col_key)
        row_entries.setdefault(row_key, {})[col_key] = probability
    # Compress the used column indices down to contiguous positions.
    column_position = {col: j for j, col in enumerate(sorted(seen_columns))}
    reduced = np.zeros((len(row_entries), len(seen_columns)))
    for i, (_, columns) in enumerate(sorted(row_entries.items())):
        for col_key, probability in columns.items():
            reduced[i, column_position[col_key]] = probability
    return reduced
56

57

58
def __sparse_flattening(
    split,
    pattern_probabilities,
    taxa,
    ban_row_patterns=None,
    ban_col_patterns=None,
    matrix_format="dok",
):
    """Build the flattening of ``split`` as a scipy sparse matrix.

    Args:
        split (str or list): The split; a string is split on "|".
        pattern_probabilities: mapping of site pattern -> probability.
        taxa: ordering of the taxa, used to index into each pattern.
        ban_row_patterns (str, optional): substring; any row whose label
            contains it more than once has its entries zeroed out.
        ban_col_patterns (str, optional): same, for column labels.
        matrix_format (str): "dok" (default) or "coo". Previously this was a
            hard-coded local, which left the "coo" branch and the format
            check unreachable dead code; it is now a real parameter.

    Returns:
        A sparse matrix of shape (4**len(split[0]), 4**len(split[1])).

    Raises:
        NotImplementedError: For an unrecognised ``matrix_format``.
    """
    if matrix_format not in ("dok", "coo"):
        raise NotImplementedError("Only dok and coo formats are currently supported")
    if isinstance(split, str):
        split = split.split("|")
    taxa_indexer = {taxon: i for i, taxon in enumerate(taxa)}
    shape = (4 ** len(split[0]), 4 ** len(split[1]))

    def _banned(row_pattern, col_pattern):
        # An entry is banned when a ban substring occurs more than once in
        # the corresponding half-pattern label.
        return (
            ban_col_patterns is not None and col_pattern.count(ban_col_patterns) > 1
        ) or (
            ban_row_patterns is not None and row_pattern.count(ban_row_patterns) > 1
        )

    if matrix_format == "coo":
        rows = []
        cols = []
        data = []
        for pattern, probability in pattern_probabilities.items():
            if probability == 0:
                continue  # COO stores explicit entries; skip zeros entirely.
            row_pattern = "".join(str(pattern[taxa_indexer[s]]) for s in split[0])
            col_pattern = "".join(str(pattern[taxa_indexer[s]]) for s in split[1])
            if _banned(row_pattern, col_pattern):
                continue  # banned entries are zero, i.e. simply not stored
            rows.append(__index_of(row_pattern))
            cols.append(__index_of(col_pattern))
            data.append(probability)
        return coo_matrix((data, (rows, cols)), shape=shape)

    # matrix_format == "dok"
    flattening = dok_matrix(shape)
    for pattern, probability in pattern_probabilities.items():
        row_pattern = "".join(str(pattern[taxa_indexer[s]]) for s in split[0])
        col_pattern = "".join(str(pattern[taxa_indexer[s]]) for s in split[1])
        row = __index_of(row_pattern)
        col = __index_of(col_pattern)
        if _banned(row_pattern, col_pattern):
            flattening[row, col] = 0  # assigning 0 leaves the entry unstored
        else:
            flattening[row, col] = probability
    return flattening
103

104

105
# Public alias: expose the banned-pattern-aware sparse flattening under an
# importable name (the double-underscore original is module-private by
# convention, and its name would be mangled if referenced from a class body).
sparse_flattening_with_banned_patterns = __sparse_flattening
106

107

108
def subflattening(split, pattern_probabilities, data=None):
    """
    A faster version of signed sum subflattening. Requires a data dictionary and can be supplied with a bundle of
    re-usable information to reduce the number of calls to the multiplications function.

    Args:
        split (str or list): The split to compute the subflattening of. A
            string is split on "|" into its two halves.
        pattern_probabilities: mapping of site pattern -> probability; may
            carry a ``taxa`` attribute giving the taxa ordering.
        data (dict, optional): re-usable cache bundle across calls. Holds
            "coeffs" (sign coefficients keyed by (pattern, table_pattern))
            and "labels" (label lists keyed by split-half length), both
            filled lazily below.

    Returns:
        numpy.ndarray of shape (3*len(split[0]) + 1, 3*len(split[1]) + 1).
    """
    state_space = constants.DNA_state_space
    try:
        # An Alignment-style object supplies the canonical taxa ordering.
        taxa = pattern_probabilities.taxa
    except AttributeError:
        # Otherwise recover a deterministic ordering from the split itself.
        taxa = sorted(set.union(*map(set, split)))
    taxa_indexer = {taxon: i for i, taxon in enumerate(taxa)}

    if data is None:
        data = {}
    try:
        coeffs = data["coeffs"]
        labels = data["labels"]
    # NOTE(review): if only one of the two keys is missing, BOTH caches are
    # reset here — a partially-populated `data` bundle loses its contents.
    except KeyError:
        data["coeffs"] = coeffs = {}
        data["labels"] = labels = {}

    if isinstance(split, str):
        split = split.split("|")
    sp1, sp2 = map(len, split)
    # Each axis has 3*k + 1 labels, matching __subflattening_labels_generator
    # (three non-special states at each of k positions, plus one all-special
    # label).
    subflattening = [[0 for _ in range(3 * sp2 + 1)] for _ in range(3 * sp1 + 1)]
    try:
        row_labels = labels[sp1]
    except KeyError:
        row_labels = list(__subflattening_labels_generator(sp1))
        labels[sp1] = row_labels
    try:
        col_labels = labels[sp2]
    except KeyError:
        col_labels = list(__subflattening_labels_generator(sp2))
        labels[sp2] = col_labels
    # Character pairs in `banned` keep a +1 contribution; any pair NOT in the
    # set flips the coefficient's sign below. Presumably this encodes the
    # signed (Hadamard-style) basis change used by subflattenings — TODO
    # confirm against the accompanying paper/derivation.
    banned = (
        {("C", "C"), ("G", "G"), ("A", "T")}
        | {(x, "A") for x in state_space}
        | {("T", x) for x in state_space}
    )
    for r, row in enumerate(row_labels):
        for c, col in enumerate(col_labels):
            # Full site pattern whose signed sum gives entry (r, c).
            pattern = __reconstruct_pattern(split, row, col, taxa_indexer)
            signed_sum = 0
            for table_pattern, value in pattern_probabilities.items():
                try:
                    product = coeffs[(pattern, table_pattern)]
                except KeyError:
                    # Coefficient is (-1) ** (number of non-banned pairs).
                    product = 1
                    for t in zip(pattern, table_pattern):
                        if t not in banned:
                            product *= -1
                    coeffs[(pattern, table_pattern)] = product
                signed_sum += product * value
            subflattening[r][c] = signed_sum
    return np.array(subflattening)
164

165

166
def __index_of(pattern):
    """Map a DNA pattern string to its flattening row/column index.

    The pattern is read as a base-4 numeral, most significant character
    first, with digit values taken from ``constants.DNA_state_space_dict``.
    """
    index = 0
    for symbol in pattern:
        # Horner's scheme: shift the accumulated value up one base-4 digit.
        index = 4 * index + constants.DNA_state_space_dict[symbol]
    return index
172

173

174
def __subflattening_labels_generator(length):
    """Yield the 3*length + 1 labels that index a subflattening axis.

    For each position, yield one label per non-special state with that
    state at the position and "T" everywhere else; finally yield the label
    consisting entirely of the special (last) state.
    """
    state_space = constants.DNA_state_space
    non_special_states = state_space[:-1]
    special_state = state_space[-1]
    for position in range(length):
        prefix = "T" * position
        suffix = "T" * (length - position - 1)
        for state in non_special_states:
            yield prefix + state + suffix
    yield special_state * length
190

191

192
def __reconstruct_pattern(split, row_label, col_label, taxa_indexer):
    """Re-assemble a full site pattern from its row/column label halves.

    Each character of ``row_label`` belongs to the corresponding taxon in
    ``split[0]`` (likewise ``col_label`` and ``split[1]``); characters are
    placed back into the positions given by ``taxa_indexer`` and joined in
    taxa order.
    """
    placed = {}
    for half, label in ((split[0], row_label), (split[1], col_label)):
        for k, taxon in enumerate(half):
            placed[taxa_indexer[taxon]] = label[k]
    return "".join(placed[position] for position in range(len(taxa_indexer)))
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc